Effective Theory of Transformers at Initialization

Emily Dinan, Sho Yaida, Susan Zhang

TL;DR

The paper develops an effective theory of Transformers at initialization that characterizes forward-backward signal propagation in wide, deep models built from stem, layer-normalization (LN), multi-head self-attention (MHSA), and multilayer-perceptron (MLP) blocks. By deriving blockwise preactivation statistics and neural-tangent-kernel (NTK) dynamics, it yields width-dependent initialization and per-parameter learning-rate scalings for SGD and AdamW that keep the NTK of order one in width during training. The authors validate these insights in practical experiments with Vision Transformers on ImageNet-1k and encoder-decoder Language Transformers on span denoising, observing improved training stability and, in some cases, performance gains under NTK-guided scaling. Collectively, the work bridges theoretical criticality and practical training of large-scale Transformers, offering concrete guidelines for scaling initialization and optimization hyperparameters with model width and depth.
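To make "an order-one NTK" concrete, here is a minimal sketch on a toy model rather than a Transformer: a one-hidden-layer tanh MLP at initialization whose weights are drawn with variance 1/fan-in. It assumes the learning-rate-weighted NTK convention $\Theta(x,x)=\sum_\mu \lambda_\mu (\partial f/\partial\theta_\mu)^2$ and assumes that each layer's learning rate $\lambda_\mu$ is divided by that layer's fan-in. The function name `diagonal_ntk` and every constant are hypothetical, and the paper's blockwise prescriptions for Transformers may differ in detail.

```python
import math
import random


def diagonal_ntk(width, n_in=16, scale_lr_by_fan_in=True, seed=0):
    """Theta(x, x) for f(x) = sum_i v_i * tanh(sum_j W_ij x_j) at initialization.

    Hypothetical toy model (a one-hidden-layer MLP, not a Transformer):
    W_ij ~ N(0, 1/n_in), v_i ~ N(0, 1/width), and x_j = 1 for every j.
    """
    rng = random.Random(seed)
    x = [1.0] * n_in
    x_sq = sum(xj * xj for xj in x)
    # Assumed per-parameter learning rates: 1 / fan_in for each layer,
    # or 1 for every parameter when the scaling is switched off.
    lr_w = 1.0 / n_in if scale_lr_by_fan_in else 1.0
    lr_v = 1.0 / width if scale_lr_by_fan_in else 1.0

    theta = 0.0
    for _ in range(width):
        w_row = [rng.gauss(0.0, math.sqrt(1.0 / n_in)) for _ in range(n_in)]
        v_i = rng.gauss(0.0, math.sqrt(1.0 / width))
        z_i = sum(w * xj for w, xj in zip(w_row, x))
        act = math.tanh(z_i)
        dact = 1.0 - act * act  # derivative of tanh at z_i
        # df/dv_i = tanh(z_i): contribution of the readout weights.
        theta += lr_v * act * act
        # df/dW_ij = v_i * tanh'(z_i) * x_j; summing over j gives the factor x_sq.
        theta += lr_w * (v_i * dact) ** 2 * x_sq
    return theta


if __name__ == "__main__":
    for n in (64, 256, 1024, 4096):
        unscaled = diagonal_ntk(n, scale_lr_by_fan_in=False)
        scaled = diagonal_ntk(n)
        print(n, round(unscaled, 2), round(scaled, 3))
```

With uniform learning rates, the readout term sums `width` order-one contributions, so the kernel grows roughly linearly with the hidden width; dividing each layer's learning rate by its fan-in keeps it near a width-independent constant.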

Abstract

We perform an effective-theory analysis of forward-backward signal propagation in wide and deep Transformers, i.e., residual neural networks with multi-head self-attention blocks and multilayer perceptron blocks. This analysis suggests particular width scalings of initialization and training hyperparameters for these models. We then take up such suggestions, training Vision and Language Transformers in practical setups.
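Training hyperparameters need width scalings, not just initialization, and this is especially visible for Adam-type optimizers, whose updates do not depend on the overall gradient scale. A minimal sketch, assuming a single linear unit $z=\sum_j W_j x_j$ with a unit upstream gradient and ignoring weight decay: on the first bias-corrected Adam step each weight moves by about $-\eta\,\mathrm{sign}(g_j)$, so the preactivation shifts by about $\eta\sum_j |x_j|$. This shift grows linearly with fan-in unless $\eta$ is divided by fan-in. The helper `first_adam_step_shift` and its toy setup are hypothetical and are not the paper's AdamW prescription.

```python
import math
import random


def first_adam_step_shift(fan_in, lr=1.0, scale_lr_by_fan_in=True, eps=1e-8, seed=0):
    """Return |Delta z| for z = sum_j W_j x_j after one bias-corrected Adam step.

    Hypothetical toy setup: inputs x_j ~ N(0, 1), upstream gradient dL/dz = 1,
    so dL/dW_j = x_j. Weight decay and all later steps are ignored.
    """
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
    step_lr = lr / fan_in if scale_lr_by_fan_in else lr
    shift = 0.0
    for xj in x:
        g = xj  # dL/dW_j = (dL/dz) * x_j with dL/dz = 1
        # On step 1, bias correction gives m_hat = g and v_hat = g**2,
        # so the Adam update is -lr * g / (|g| + eps), close to -lr * sign(g).
        delta_w = -step_lr * g / (abs(g) + eps)
        shift += delta_w * xj  # exact change in z, since z is linear in W
    return abs(shift)


if __name__ == "__main__":
    for n in (64, 256, 1024, 4096):
        print(n, round(first_adam_step_shift(n, scale_lr_by_fan_in=False), 1),
              round(first_adam_step_shift(n), 3))
```

Dividing the learning rate by fan-in keeps the shift near $\sqrt{2/\pi}\approx 0.8$ at every width, because that is the mean of $|x_j|$ for standard-normal inputs; without the division the shift is roughly $0.8$ times the fan-in.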

Paper Structure

This paper contains 36 sections, 121 equations, 5 figures, and 2 tables.

Figures (5)

  • Figure 1: Hyperparameter searches for standard uniform (top left), neural-tangent (top right), hybrid neural-tangent--maximal-update (bottom left), and maximal-update (bottom right) scaling strategies for Vision Transformers trained by AdamW$_{(\beta_1,\beta_2,\epsilon)=(0.9,0.999,\text{1e-8})}$. For each, the top-one validation accuracy on the ImageNet-1k dataset is plotted as a function of training epochs. In the legend, we record the training hyperparameter pair $(\texttt{lr},\texttt{wd})$ [with the max top-one validation accuracy along each trajectory].
  • Figure 2: Comparison of the standard uniform (black), neural-tangent (blue), hybrid neural-tangent--maximal-update (purple), and maximal-update (red) scaling strategies for Vision Transformers trained by AdamW$_{(\beta_1,\beta_2,\epsilon)=(0.9,0.999,\text{1e-8})}$. For each, the top-one validation accuracy on ImageNet-1k is plotted as a function of training epochs for three different seeds (shown in different shades). In the legend, we record the optimal training hyperparameter pair $(\texttt{lr}^{\star},\texttt{wd}^{\star})$ for each scaling strategy [with the max top-one validation accuracy along each trajectory].
  • Figure 3: Hyperparameter searches for standard uniform (left) and neural-tangent (right) scaling strategies for BART-large trained by AdamW$_{(\beta_1,\beta_2,\epsilon)=(0.9,0.998,\text{1e-6})}$. For each, the validation loss is plotted as a function of training updates. In the legend, we record the global learning rate $\texttt{lr}$ [with the minimum validation loss along each trajectory]. Note that (i) for both scaling strategies, the higher-$\texttt{lr}$ runs (red) experienced gradient overflow, (ii) lowering $\texttt{lr}$ (going from black to blue) degrades model performance more for the standard uniform scaling strategy than for the neural-tangent scaling strategy, and (iii) for the neural-tangent scaling strategy, the $\texttt{lr}=32.768$ run, despite its "hiccup" at around 90,000 iterations, caught up with the $\texttt{lr}=16.384$ run by the end of training.
  • Figure 4: Comparison of the standard uniform (black) and neural-tangent (blue) scaling strategies for BART-large trained by AdamW$_{(\beta_1,\beta_2,\epsilon)=(0.9,0.998,\text{1e-6})}$. For each, the validation loss is plotted as a function of training updates. In the legend, we record the selected training hyperparameter pair $(\texttt{lr},\texttt{wd})$ for each scaling strategy [with the minimum validation loss along each trajectory].
  • Figure 5: Comparison of the standard uniform (black) and neural-tangent (blue) scaling strategies for R2C2 trained by AdamW$_{(\beta_1,\beta_2,\epsilon)=(0.9,0.998,\text{1e-6})}$; for further comparison, we also include the corresponding runs for BART-large (dashed). For each, the validation loss is plotted as a function of training updates. In the legend, we record the zero-shotted (i.e., transferred without retuning) training hyperparameter pair $(\texttt{lr},\texttt{wd})$ for each scaling strategy [with the minimum validation loss along each trajectory].