ResNets of All Shapes and Sizes: Convergence of Training Dynamics in the Large-scale Limit

Louis-Pierre Chaintron, Lénaïc Chizat, Javier Maass

Abstract

We establish convergence of the training dynamics of residual neural networks (ResNets) to their joint limit of infinite depth L, hidden width M, and embedding dimension D. Specifically, we consider ResNets with two-layer perceptron blocks in the maximal local feature update (MLU) regime and prove that, after a bounded number of training steps, the error between the ResNet and its large-scale limit is O(1/L + sqrt(D/(L M)) + 1/sqrt(D)). This error rate is empirically tight when measured in embedding space. For a budget of P = Theta(L M D) parameters, this yields a convergence rate O(P^(-1/6)) for the scalings of (L, M, D) that minimize the bound. Our analysis exploits in an essential way the depth-two structure of residual blocks and applies formally to a broad class of state-of-the-art architectures, including Transformers with bounded key-query dimension. From a technical viewpoint, this work completes the program initiated in the companion paper [Chi25], which proves that for a fixed embedding dimension D, the training dynamics converges to a Mean ODE dynamics at rate O(1/L + sqrt(D/(L M))). Here, we study the large-D limit of this Mean ODE model and establish convergence at rate O(1/sqrt(D)), which yields the above bound by the triangle inequality. To handle the rich probabilistic structure of the limit dynamics and obtain one of the first rigorous quantitative convergence results for a DMFT-type limit, we combine the cavity method with propagation-of-chaos arguments at a functional level on so-called skeleton maps, which express the weight updates as functions of CLT-type sums from the past.
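The P^(-1/6) rate follows from a short exponent count. Since L M D = P (up to constants), the middle term equals sqrt(D^2/P) = D/sqrt(P), so it depends on D only: balancing D/sqrt(P) against 1/sqrt(D) gives D = P^(1/3), and any L >= P^(1/6) keeps 1/L from dominating. The sketch below checks this by exact grid search, assuming power-law scalings L = P^a, M = P^b, D = P^c and ignoring constant factors; this parametrization and the helper names are illustrative assumptions, not taken from the paper.

```python
from fractions import Fraction
from itertools import product

def rate_exponent(a, b, c):
    """Exponent r such that 1/L + sqrt(D/(L M)) + 1/sqrt(D) = O(P^-r)
    when L = P^a, M = P^b, D = P^c with a + b + c = 1 (assumed scalings)."""
    return min(a,                # 1/L           ~ P^-a
               (a + b - c) / 2,  # sqrt(D/(L M)) ~ P^-(a+b-c)/2
               c / 2)            # 1/sqrt(D)     ~ P^-c/2

def optimal_exponents(n=60):
    """Exact grid search over exponents (a, b, c) on the simplex, step 1/n.
    Returns the best decay exponent and every grid point that attains it."""
    grid = [Fraction(i, n) for i in range(n + 1)]
    best_r, optimal = Fraction(-1), []
    for a, c in product(grid, grid):
        b = 1 - a - c
        if b < 0:
            continue  # outside the simplex
        r = rate_exponent(a, b, c)
        if r > best_r:
            best_r, optimal = r, [(a, b, c)]
        elif r == best_r:
            optimal.append((a, b, c))
    return best_r, optimal

best_r, optimal = optimal_exponents()
print("best exponent:", best_r)                      # 1/6
print("D exponent(s):", {c for _, _, c in optimal})  # {Fraction(1, 3)}
print("L exponent range:", min(a for a, _, _ in optimal), "to",
      max(a for a, _, _ in optimal))                 # 1/6 to 2/3
```

Under this reading, the optimal D exponent is unique (c = 1/3), while the L exponent can range over [1/6, 2/3] with b = 2/3 - a, which is consistent with the abstract's reference to several minimizing scalings of (L, M, D).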

Paper Structure (66 sections, 44 theorems, 381 equations, 4 figures)

Key Result

Theorem 1.1

Let $C_0>0$. Assume that $\rho^{(i)}$ (the $i$-th derivative of $\rho$) is bounded for $i\in [1:5]$ and that $\nabla \mathrm{loss}_i$ is Lipschitz for $i\in [1:N]$. Consider the dynamics $\mathrm{(ClippedGD)}$ with $\eta_u,\eta_v,\sigma_u,\sigma_v,\sigma_{\mathrm{in}}, \sigma_{\mathrm{out}}\ge \ldots$ … provided that the right-hand side is smaller than $c_2$.

Figures (4)

  • Figure 1: Comparison of the experimental convergence rate of the hidden representation $\mathtt{h}^L_k$ with the theoretical upper bound $\Vert [ \frac{\alpha\sqrt{D}}{\sqrt{ML}}, \frac{\beta}{\sqrt{D}}]\Vert_2$ from Theorem 1.1, with $\alpha=0.8$ and $\beta=2.5$ manually adjusted to fit observations (plain lines). The y-axis shows RMS error (averaged over $5$ random repetitions and over the dataset) on the last hidden state $\mathtt{h}^L_k$ after $k=15$ GD steps.
  • Figure 2: Comparison of the experimental convergence rate of the output $\mathtt{y}_k$ with the conjectured rate $\Vert [ \frac{\alpha D}{ML}, \frac{\beta}{\sqrt{D}}]\Vert_2$ (which is smaller than the rate from Theorem 1.1), with coefficients $\alpha=0.15$ and $\beta=0.9$ manually adjusted to fit observations (plain lines). The y-axis shows RMS error (averaged over $5$ random repetitions and over the dataset) on the output $\mathtt{y}_k$ after $k=15$ GD steps.
  • Figure 3: Histograms of the coordinates of the forward pass in a finite linear ResNet ($L=200$, $M=D=1000$), compared to the limit model, at training iterations $k \in\{0,4,8\}$. The red lines represent the density of the Gaussian random variable $H_k(\ell/L)$. The right plot shows the trajectories across depth of a few embedding coordinates $d$, compared with coupled samples of the limit stochastic process $\ell\mapsto H_k^d(\ell/L)$.
  • Figure 4: Forward-pass RMS error after $k=5$ GD steps, $\Delta_5^h$, in the linear case $\rho(s)=s$, between finite ResNets of variable size and the true limit model. Dotted lines display the function $\|[0.67\sqrt{\frac{D}{ML}}, 0.44\frac{1}{\sqrt{D}}]\|_2$ obtained by fitting our error rate to the data. The error term $O(\frac{1}{L})$ is negligible in this regime. The y-axis shows RMS error (averaged over $60$ random repetitions with a single datapoint). In the last plot, we overlay the predicted scaling $\Theta(P^{-1/6})$.
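Figures 1 and 4 compare measured RMS errors with curves of the form $\|[\alpha\sqrt{D/(ML)}, \beta/\sqrt{D}]\|_2$. For Figure 1 the coefficients were adjusted by hand, and the fitting procedure behind Figure 4 is not described here. The following is a minimal sketch of one way to fit them, assuming a least-squares fit on squared errors; the helper names and the network sizes are hypothetical, and the "measurements" are synthetic values generated from $\alpha=0.67$, $\beta=0.44$, not data from the paper.

```python
import math

def predicted_error(L, M, D, alpha, beta):
    """Fitted curve ||[alpha*sqrt(D/(M L)), beta/sqrt(D)]||_2 from the captions
    (the O(1/L) term is left out, as in Figures 1 and 4)."""
    return math.hypot(alpha * math.sqrt(D / (M * L)), beta / math.sqrt(D))

def fit_alpha_beta(sizes, errors):
    """Hypothetical helper: least-squares fit of (alpha, beta) on squared errors.

    With x2 = D/(M L) and y2 = 1/D, the model err^2 = A*x2 + B*y2 is linear
    in (A, B) = (alpha^2, beta^2), so the 2x2 normal equations can be solved
    directly with Cramer's rule.
    """
    sxx = sxy = syy = sex = sey = 0.0
    for (L, M, D), err in zip(sizes, errors):
        x2, y2, e2 = D / (M * L), 1.0 / D, err * err
        sxx += x2 * x2
        sxy += x2 * y2
        syy += y2 * y2
        sex += e2 * x2
        sey += e2 * y2
    det = sxx * syy - sxy * sxy
    if det == 0.0:
        raise ValueError("the sizes do not separate the two error terms")
    A = (sex * syy - sxy * sey) / det
    B = (sxx * sey - sxy * sex) / det
    # Clip at zero before taking square roots (noisy data can give A or B < 0).
    return math.sqrt(max(A, 0.0)), math.sqrt(max(B, 0.0))

# Synthetic check: noiseless errors generated from the Figure 4 coefficients.
sizes = [(64, 256, 64), (64, 1024, 256), (128, 512, 1024), (256, 2048, 128)]
errors = [predicted_error(L, M, D, 0.67, 0.44) for (L, M, D) in sizes]
alpha_hat, beta_hat = fit_alpha_beta(sizes, errors)
print(f"alpha ~ {alpha_hat:.3f}, beta ~ {beta_hat:.3f}")
```

Squaring the errors keeps the problem linear in $(\alpha^2, \beta^2)$; a relative-error fit would weight small errors more heavily but would require an iterative solver. For the conjectured rate of Figure 2, the same sketch applies with `x2` replaced by `(D / (M * L)) ** 2`.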

Theorems & Definitions (91)

  • Theorem 1.1: Quantitative large-scale limit of ResNets
  • Remark 1: Training the embedding matrices
  • Definition 2.1: Finite-dimensional skeleton maps
  • Remark 2: Non-anticipative skeleton maps
  • Remark 3: Fixed-point formulation
  • Definition 2.2: Mean-Field Skeleton Maps
  • Definition 2.3: Skeleton vectors
  • Remark 4: Skeleton Maps for $N>1$
  • Remark 5: Linear mean-field structure
  • Theorem 2.4: Well-posedness
  • ...and 81 more