Stability of Transformers under Layer Normalization

Kelvin Kan, Xingjian Li, Benjamin J. Zhang, Tuhin Sahai, Stanley Osher, Krishna Kumar, Markos A. Katsoulakis

TL;DR

This work investigates why Transformer training can be unstable and how the placement of Layer Normalization (LN) affects both forward activations and backward gradients. By casting Transformer training as a continuous-time mean-field control problem, it shows that Pre-LN leads to unbounded hidden-state growth and ill-defined optimality, while Peri-LN constrains growth to linear/quadratic bounds and yields gradient stability. The authors introduce a residual-step scaling with factor $\Delta t<1$ that further improves both forward and backward stability. They validate these insights with GPT-2-scale experiments, which demonstrate robust stability and competitive performance for Peri-LN. The framework offers a principled workflow to screen architectural variants before expensive training and derives generalization bounds via Wasserstein-based uncertainty quantification.
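The difference between the two LN placements can be seen in a toy residual stack. The sketch below is a minimal pure-Python illustration, not the paper's architecture or experiments. It assumes the update forms x ← x + Δt·F(LN(x)) for Pre-LN and x ← x + Δt·LN(F(LN(x))) for Peri-LN. The random linear map `make_sublayer` is a hypothetical stand-in for attention/MLP, and sharing one (γ, β) pair across all LNs is a simplification.

```python
import math
import random


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over one vector: gamma * (x - mean) / sqrt(var + eps) + beta."""
    d = len(x)
    mu = sum(x) / d
    var = sum((xi - mu) ** 2 for xi in x) / d  # population variance
    std = math.sqrt(var + eps)
    return [g * (xi - mu) / std + b for xi, g, b in zip(x, gamma, beta)]


def make_sublayer(d, scale, rng):
    """Hypothetical stand-in for an attention/MLP sublayer: y -> W y, W_ij ~ N(0, scale^2)."""
    W = [[rng.gauss(0.0, scale) for _ in range(d)] for _ in range(d)]
    return lambda y: [sum(w * yj for w, yj in zip(row, y)) for row in W]


def pre_ln_step(x, F, gamma, beta, dt=1.0):
    # Pre-LN: normalize the sublayer input; the sublayer output is added as-is.
    f = F(layer_norm(x, gamma, beta))
    return [xi + dt * fi for xi, fi in zip(x, f)]


def peri_ln_step(x, F, gamma, beta, dt=1.0):
    # Peri-LN: normalize both the sublayer input and its output before the residual add.
    y = layer_norm(F(layer_norm(x, gamma, beta)), gamma, beta)
    return [xi + dt * yi for xi, yi in zip(x, y)]


def l2(v):
    return math.sqrt(sum(vi * vi for vi in v))


def run(step, depth=24, d=16, scale=3.0, dt=1.0, seed=0):
    """Apply `depth` residual steps and return the hidden-state norm after each step."""
    rng = random.Random(seed)
    gamma, beta = [1.0] * d, [0.0] * d
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norms = [l2(x)]
    for _ in range(depth):
        x = step(x, make_sublayer(d, scale, rng), gamma, beta, dt)
        norms.append(l2(x))
    return norms


pre = run(pre_ln_step)
peri = run(peri_ln_step)
peri_half = run(peri_ln_step, dt=0.5)  # residual-step scaling with dt < 1
print(f"final norms: Pre-LN {pre[-1]:.1f}, Peri-LN {peri[-1]:.1f}, Peri-LN dt=0.5 {peri_half[-1]:.1f}")
```

In Peri-LN, each increment is itself a LayerNorm output. With γ = 1 and β = 0, its norm therefore cannot exceed Δt·√d, whatever the weight scale. The Pre-LN increment instead scales with the sublayer weights.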

Abstract

Despite their widespread use, training deep Transformers can be unstable. Layer normalization, a standard component, improves training stability, but its placement has often been ad hoc. In this paper, we conduct a principled study on the forward (hidden states) and backward (gradient) stability of Transformers under different layer normalization placements. Our theory provides key insights into the training dynamics: whether training drives Transformers toward regular solutions or pathological behaviors. For forward stability, we derive explicit bounds on the growth of hidden states in trained Transformers. For backward stability, we analyze how layer normalization affects the backpropagation of gradients, thereby explaining the training dynamics of each layer normalization placement. Our analysis also guides the scaling of residual steps in Transformer blocks, where appropriate choices can further improve stability and performance. Our numerical results corroborate our theoretical findings. Beyond these results, our framework provides a principled way to sanity-check the stability of Transformers under new architectural modifications, offering guidance for future designs.
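The following worked inequality shows how an explicit hidden-state growth bound can arise. It is a simple illustration, not the paper's theorem. It assumes a Peri-LN update ${\bf x}_{l+1} = {\bf x}_l + \Delta t\,\text{LN}^{(l)}_{\rm out}(F_l(\text{LN}^{(l)}_{\rm in}({\bf x}_l)))$, with LN using the population variance and no $\epsilon$. In that case the normalized vector ${\bf u}$ has norm exactly $\sqrt{d}$, so

$$\|\text{LN}({\bf y};{\boldsymbol\gamma},{\boldsymbol\beta})\|_2 = \|{\boldsymbol\Gamma}{\bf u}+{\boldsymbol\beta}\|_2 \le \|{\boldsymbol\gamma}\|_\infty\sqrt{d}+\|{\boldsymbol\beta}\|_2, \qquad {\bf u}=\frac{{\bf y}-\mu({\bf y}){\bf 1}}{\sigma({\bf y})},\quad \|{\bf u}\|_2=\sqrt{d}.$$

Applying the triangle inequality along the residual stream gives

$$\|{\bf x}_L\|_2 \le \|{\bf x}_0\|_2 + \Delta t\sum_{l=0}^{L-1}\big\|\text{LN}^{(l)}_{\rm out}\big(F_l(\text{LN}^{(l)}_{\rm in}({\bf x}_l))\big)\big\|_2 \le \|{\bf x}_0\|_2 + L\,\Delta t\,C,\qquad C=\max_{l}\big(\|{\boldsymbol\gamma}^{(l)}\|_\infty\sqrt{d}+\|{\boldsymbol\beta}^{(l)}\|_2\big).$$

The norm therefore grows at most linearly in depth $L$, and squaring gives a bound quadratic in $L$. A factor $\Delta t<1$ shrinks the slope. For Pre-LN the increment $F_l(\text{LN}({\bf x}_l))$ is added unnormalized, so the analogous bound depends on the sublayer weights.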

Paper Structure

This paper contains 52 sections, 20 theorems, 113 equations, 1 figure, 2 tables.

Key Result

Lemma 1

The layer normalization output ${\bf z}=\text{LN}({\bf x}; {\boldsymbol \gamma}, {\boldsymbol \beta})$ lies on the ellipsoid $\{{\bf z}\in\mathbb{R}^d : \|{\boldsymbol \Gamma}^{-1}({\bf z}-{\boldsymbol \beta})\|_2^2 = d\}$, where ${\boldsymbol \Gamma} = {\rm diag}({\boldsymbol \gamma})\in \mathbb{R}^{d\times d}$.
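Lemma 1 can be checked numerically. Assume the standard definition $\text{LN}({\bf x}) = {\boldsymbol\gamma}\odot({\bf x}-\mu)/\sigma + {\boldsymbol\beta}$, with population variance, no $\epsilon$, and all $\gamma_i \neq 0$. Then ${\boldsymbol\Gamma}^{-1}({\bf z}-{\boldsymbol\beta})$ recovers the normalized vector, whose squared norm is $d$ and whose entries sum to zero. The helpers below are a hypothetical sketch that measures both gaps.

```python
import math
import random


def layer_norm(x, gamma, beta):
    """LayerNorm with population variance and no epsilon (assumed form)."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((xi - mu) ** 2 for xi in x) / d)
    return [g * (xi - mu) / sigma + b for xi, g, b in zip(x, gamma, beta)]


def ellipsoid_gap(z, gamma, beta):
    """||Gamma^{-1}(z - beta)||^2 - d; zero when z lies on the ellipsoid."""
    d = len(z)
    return sum(((zi - b) / g) ** 2 for zi, g, b in zip(z, gamma, beta)) - d


def centering_gap(z, gamma, beta):
    """1^T Gamma^{-1}(z - beta); also zero, because LN subtracts the mean."""
    return sum((zi - b) / g for zi, g, b in zip(z, gamma, beta))


rng = random.Random(0)
d = 8
gamma = [rng.uniform(0.5, 2.0) for _ in range(d)]  # nonzero scales, so Gamma is invertible
beta = [rng.gauss(0.0, 1.0) for _ in range(d)]
x = [rng.gauss(0.0, 10.0) for _ in range(d)]
z = layer_norm(x, gamma, beta)
print(ellipsoid_gap(z, gamma, beta), centering_gap(z, gamma, beta))  # both ~0 up to rounding
```

The centering check shows that the output also lies in a hyperplane. The LN output set is thus the intersection of the ellipsoid with that hyperplane.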

Figures (1)

  • Figure 1: Moments of hidden states across layers for the trained GPT-2 XL. With tuned weight decay for Pre-LN, the growth rate remains below the theoretical exponential upper bound; exponential growth has been observed, e.g., in Kim et al. (2025). The residual step scaling by $\Delta t$ effectively controls the growth at no extra cost.
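Per-layer curves like those in Figure 1 can be produced by recording the hidden state after each block and averaging statistics over tokens. The sketch below assumes hidden states stored as nested lists indexed `[layer][token][feature]`. The two definitions used here (mean of $|h|^p$ over all entries, and per-feature variance across tokens averaged over features) are illustrative assumptions; the paper's exact moment definitions may differ.

```python
def entrywise_moments(hidden_states, p=2):
    """Per layer: mean of |h|^p over all tokens and features (assumed definition)."""
    out = []
    for layer in hidden_states:
        vals = [abs(h) ** p for token in layer for h in token]
        out.append(sum(vals) / len(vals))
    return out


def datawise_variance(hidden_states):
    """Per layer: variance across tokens of each feature, averaged over features (assumed)."""
    out = []
    for layer in hidden_states:
        n, d = len(layer), len(layer[0])
        total = 0.0
        for j in range(d):
            col = [token[j] for token in layer]
            m = sum(col) / n
            total += sum((c - m) ** 2 for c in col) / n
        out.append(total / d)
    return out


# Toy example: 2 layers, 2 tokens, 2 features.
states = [[[1.0, -1.0], [1.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
print(entrywise_moments(states))  # [1.0, 2.0]
print(datawise_variance(states))  # [0.5, 1.0]
```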

Theorems & Definitions (30)

  • Lemma 1
  • Theorem 2
  • Theorem 3
  • Theorem 4: Controlled Growth of Entry-wise Moments
  • Theorem 5: Quadratic Growth of Data-wise Variance
  • Theorem 6
  • Proposition 7
  • Proposition 8
  • Lemma 8
  • proof
  • ...and 20 more