
Unifying back-propagation and forward-forward algorithms through model predictive control

Lianhai Ren, Qianxiao Li

TL;DR

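This work unifies back-propagation and forward-forward training through model predictive control, analyzes the resulting performance-efficiency trade-off precisely on a deep linear network, and proposes a principled method to choose the optimization horizon based on given objectives and model specifications.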

Abstract

We introduce a Model Predictive Control (MPC) framework for training deep neural networks, systematically unifying the Back-Propagation (BP) and Forward-Forward (FF) algorithms. At the same time, it gives rise to a range of intermediate training algorithms with varying look-forward horizons, leading to a performance-efficiency trade-off. We perform a precise analysis of this trade-off on a deep linear network, where the qualitative conclusions carry over to general networks. Based on our analysis, we propose a principled method to choose the optimization horizon based on given objectives and model specifications. Numerical results on various models and tasks demonstrate the versatility of our method.
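To make the role of the horizon concrete, here is a minimal sketch on a toy scalar residual chain. It assumes, as a simplification not taken from the paper, that the horizon-$h$ gradient for block $t$ is the gradient of a squared loss evaluated on the state $h$ blocks ahead (capped at the last block). The names `forward`, `horizon_grads`, and `train`, the loss, and all hyperparameters are hypothetical choices for illustration only.

```python
def forward(w, x0):
    """States x(0..T) of the toy chain x(t+1) = (1 + w[t]/T) * x(t)."""
    T = len(w)
    xs = [x0]
    for wt in w:
        xs.append((1.0 + wt / T) * xs[-1])
    return xs


def horizon_grads(w, x0, y, h):
    """Per-block gradients under a look-forward horizon h.

    ASSUMPTION (illustrative): block t sees the loss 0.5 * (x(end) - y)**2
    with end = min(t + h, T), so h = T gives the usual back-propagation
    gradient and h = 1 gives a purely local, FF-style update.
    """
    T = len(w)
    xs = forward(w, x0)
    grads = []
    for t in range(T):
        end = min(t + h, T)
        p = xs[end] - y                  # dL/dx(end)
        for s in range(end - 1, t, -1):  # back through blocks end-1 .. t+1
            p *= 1.0 + w[s] / T
        grads.append(p * xs[t] / T)      # dL/dw[t]
    return grads


def train(h, T=8, x0=1.0, y=2.0, lr=1.0, steps=200):
    """Plain gradient descent using horizon-h gradients; returns the final loss."""
    w = [0.0] * T
    for _ in range(steps):
        g = horizon_grads(w, x0, y, h)
        w = [wt - lr * gt for wt, gt in zip(w, g)]
    return 0.5 * (forward(w, x0)[-1] - y) ** 2


if __name__ == "__main__":
    for h in (1, 4, 8):
        print(f"h={h}: final loss {train(h):.3e}")
```

In this sketch the backward loop for block $t$ runs over at most $h-1$ later blocks, which is where a shorter horizon saves work. The toy problem is far too simple to say anything about the accuracy side of the trade-off.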

Paper Structure

This paper contains 45 sections, 8 theorems, 73 equations, 9 figures, 3 tables, 1 algorithm.

Key Result

Theorem 3.4

Let $W(t)=I+\frac{1}{T}\tilde{W}(t)$, where $\{\tilde{W}(t)\}$ are matrices with bounded 2-norm, i.e., $\exists c>0$ such that $\|\tilde{W}(t)\|_2\leq c$ for all $t$. Denote by $\theta_h$ the angle between $g_h$ and $g_T$. When $T\to\infty$ and $h\to\infty$ with $\frac{h}{T}=\alpha$, $1-\cos^2(\theta_{h})= O((1-\frac{h}{
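The quantity $1-\cos^2(\theta_h)$ can be probed numerically. The sketch below builds a small deep linear network with $W(t)=I+\frac{1}{T}\tilde{W}(t)$ and compares a horizon-$h$ gradient with the full gradient $g_T$. It rests on assumptions that are not from the paper: $g_h$ stacks, over all blocks, the gradient of a squared loss evaluated on the state $h$ blocks ahead (capped at $T$), and the entries of $\tilde{W}(t)$ are uniform random in $[-c, c]$, which keeps the 2-norm bounded. All function names are hypothetical.

```python
import random


def matvec(A, v):
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def make_weights(T, d, c=1.0, seed=0):
    """W(t) = I + W_tilde(t)/T with W_tilde entries uniform in [-c, c]."""
    rng = random.Random(seed)
    return [
        [[(1.0 if i == j else 0.0) + rng.uniform(-c, c) / T for j in range(d)]
         for i in range(d)]
        for _ in range(T)
    ]


def horizon_gradient(weights, x0, y, h):
    """Flattened horizon-h gradient w.r.t. all W(t).

    ASSUMPTION (illustrative): block t uses the loss 0.5 * ||x(end) - y||^2
    with end = min(t + h, T); h = T recovers ordinary back-propagation.
    """
    T = len(weights)
    states = [x0]
    for W in weights:
        states.append(matvec(W, states[-1]))
    flat = []
    for t in range(T):
        end = min(t + h, T)
        p = [a - b for a, b in zip(states[end], y)]  # dL/dx(end)
        for s in range(end - 1, t, -1):              # back through W(end-1) .. W(t+1)
            p = matvec(transpose(weights[s]), p)
        # dL/dW(t) is the outer product p x(t)^T, stored row by row
        flat.extend(pi * xj for pi in p for xj in states[t])
    return flat


def sin2_angle(g, ref):
    """1 - cos^2 of the angle between two flattened gradients."""
    dot = sum(a * b for a, b in zip(g, ref))
    return 1.0 - dot * dot / (sum(a * a for a in g) * sum(b * b for b in ref))


if __name__ == "__main__":
    T, d = 32, 3
    x0, y = [1.0, 0.5, -0.5], [0.2, -1.0, 0.3]
    ws = make_weights(T, d)
    g_T = horizon_gradient(ws, x0, y, T)
    for h in (1, 8, 16, 24, 32):
        g_h = horizon_gradient(ws, x0, y, h)
        print(f"h/T={h / T:.2f}  1-cos^2(theta_h)={sin2_angle(g_h, g_T):.3e}")
```

By construction the deviation is zero (up to rounding) at $h=T$. How it behaves for smaller $h/T$ depends on the random draw and on the assumed loss, so the printed values are only a sanity check and do not verify the theorem.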

Figures (9)

  • Figure 1: Diagram of the MPC framework on a 4-block model. Black arrows denote the forward pass and red arrows the backward pass; $\nabla_t\triangleq g_h(u(t))$ is the gradient of the $t$-th block. MPC uses partial gradient propagation: BP is MPC with the full horizon ($h=T$), while FF is MPC with horizon 1 ($h=1$).
  • Figure 2: Relationship Between $g_h$ and $h$ on different models. The x-axis shows $T-h$ and y-axis shows $1-\cos(\theta_h)$. Each line represents a different training epoch. Left: Linear residual NN, Middle: Residual MLP, Right: ResNet-62
  • Figure 3: Test accuracy and memory usage of full tuning ResNet-50 and LoRA tuning ViT-b16 on CIFAR100. The dark line shows the loss at the final epoch and the light bars show the memory usage at each horizon. Left: ResNet-50, Right: ViT-b16
  • Figure 4: Detailed snapshots in the training procedure of the full tuning ResNet-50 and LoRA tuning ViT-b16 with different horizons. Each line denotes the training loss of one recording epoch. Left: ResNet-50, Right: ViT-b16
  • Figure 5: Relative performance $Rel(a)$ of different algorithms on various models, costs, and objectives. Top: heatmap of the relative objective value of the baselines and the horizon selection algorithms; y-axis: algorithms; x-axis: models, tasks, and objectives (see the paper's notation table); color: relative objective value; a cross '$\times$' marks an infeasible solution. Bottom: average relative performance of the algorithms on each model (an infeasible solution is counted as 1.5)
  • ...and 4 more figures

Theorems & Definitions (13)

  • Remark 3.1
  • Remark 3.2
  • Remark 3.3
  • Theorem 3.4: (Informal) Gradient Deviation in Deep Linear Network
  • Theorem (Formal): Gradient Deviation in Deep Linear Network
  • Lemma (5 entries, each listed under the formal theorem): Gradient Deviation in Deep Linear Network
  • ...and 3 more