Two failure modes of deep transformers and how to avoid them: a unified theory of signal propagation at initialisation

Alessio Giorlandino, Sebastian Goldt

TL;DR

The paper addresses how to initialise deep Transformers so that information is not lost during the forward and backward passes. By mapping self-attention to a Random Energy Model, it derives forward and backward propagation equations and identifies a critical initialisation scale $\beta_c$ that separates rank collapse from entropy collapse, yielding trainability diagrams that predict viable hyperparameters. It combines self-attention with skip connections, LayerNorm, and MLPs into a unified, quantitative framework and validates it with case studies on standard BERT-like architectures and LayerNorm placements. The results offer principled initialisation guidelines and motivate gain-controlled attention as a way to avoid both collapse modes, enabling reliable training at scale.

Abstract

Finding the right initialisation for neural networks is crucial to ensure smooth training and good performance. In transformers, the wrong initialisation can lead to one of two failure modes of self-attention layers: rank collapse, where all tokens collapse into similar representations, and entropy collapse, where highly concentrated attention scores lead to training instability. While previous work has studied different scaling regimes for transformers, an asymptotically exact, down-to-the-constant prescription for how to initialise transformers has so far been lacking. Here, we provide an analytical theory of signal propagation through deep transformers with self-attention, layer normalisation, skip connections and MLP. Our theory yields a simple algorithm to compute trainability diagrams that identify the correct choice of initialisation hyper-parameters for a given architecture. We overcome the key challenge, an exact treatment of the self-attention layer, by establishing a formal parallel with the Random Energy Model from statistical physics. We also analyse gradients in the backward path and determine the regime where gradients vanish at initialisation. We demonstrate the versatility of our framework through three case studies. Our theoretical framework gives a unified perspective on the two failure modes of self-attention and gives quantitative predictions on the scale of both weights and residual connections that guarantee smooth training.

Paper Structure

This paper contains 43 sections, 85 equations, 12 figures, 2 algorithms.

Figures (12)

  • Figure 1: Two failure modes of Transformers at initialisation, and how to avoid them. (a) Rank collapse occurs when the self-attention layer attends uniformly to all tokens, mapping all input tokens into the same output token. (b) Entropy collapse is a regime of highly saturated attention matrices which attend to random, semantically meaningless patterns, leading to training instability \cite{zhai2023stabilizing}. (c) Trainability diagram for a 60-layer BERT Transformer, obtained from our analytical theory of signal propagation, see \ref{alg:postnorm}. Depending on the strength of the self-attention residual connections $\alpha_{\mathrm{SA}}$ (\ref{eq:alpha}) and the scale of initial key and query weights $\beta$ (\ref{eq:init-of-QK}), we delineate the three regimes of rank collapse, entropy collapse, and the regime where the Transformer is trainable (blue). (d) Average cosine similarity between token embeddings of a sequence taken from the TinyStories dataset as it propagates through the layers of a vanilla BERT model for different self-attention residual strengths; empirical measurements (dots) closely follow theoretical predictions (solid lines). Sufficiently large residual connections $\alpha_{\mathrm{SA}}$ are key to preventing the similarity between tokens from becoming unity, which would indicate rank collapse. (e) Test loss of a 60-layer BERT model on TinyStories for two initialisations from each regime. Models suffering from rank or entropy collapse at initialisation fail to train, as predicted by theory. Full experimental details in \ref{app:figure1}.
  • Figure 2: Phase diagram for a single layer of self-attention. We use \ref{result1} to plot the average cosine similarity between pairs of tokens after one layer of self-attention as a function of the query/key variance parameter $\beta$ and the input average cosine similarity $\rho$. (Left): Theoretical phase diagram obtained from \ref{result1} (with $q = 1$ and $p = \rho$). For $\beta < \beta_c$, we observe a rank collapse phase, where all input tokens map to a single output direction and the cosine similarity saturates at 1. For $\beta > \beta_c$, token diversity is preserved, but entropy collapse emerges. (Right): Simulations with embedding dimension $d = 512$ and sequence length $T = 1024$ qualitatively reproduce the theoretical transition, with deviations attributed to finite-size effects, as discussed in \ref{sec:result1}.
  • Figure 3: (a, b) A phase transition in the impact of query/key initialisation on training dynamics. Average Shannon entropy of the attention rows and test loss of a Transformer with a single layer of self-attention trained on masked language modelling on TinyStories as we vary the scale of the initialisation from small to large initial weights (blue to red). Small initial weights (blue) permit attention to diversify over time, supporting effective learning, while large-variance initialisation (red) collapses the attention to only a few tokens, visible in an entropy that quickly drops to near zero. Here $\beta_c(\rho=0) = \sqrt{2}$. (c) Norm of the query gradient. Frobenius norm of the gradient of the loss with respect to query weights for various combinations of sequence length $T=2048, 4096, 8192$ and embedding dimension $d=256, 512, 1024$. As predicted by \ref{result2}, gradients collapse for different $T$ and $d$, and vanishing gradients afflict the low-$\beta$ regime.
  • Figure 4: Theoretical prediction of the evolution with depth of the average cosine similarity for the standard Transformer and the Gain-controlled Transformer under both LN strategies. Rank collapse is avoided simply by removing the mean value in the self-attention layer. Here, we set $\alpha_{\text{SA}} = \alpha_{\text{MLP}} = 1$.
  • Figure 5: Comparison of theory and experiments ($T=10^5$) for the computation of $Y^{(2)}(\beta)$; finite-size effects are visible around the phase transition.
  • ...and 7 more figures
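
The two diagnostics that recur in the captions above, the average cosine similarity between token embeddings (Figures 1d and 2) and the average Shannon entropy of the attention rows (Figure 3), can be illustrated with a minimal NumPy sketch. This is not the authors' code. It assumes a single random self-attention layer with identity value weights and no residual connection, LayerNorm or MLP, and the parameter `score_scale` is a hypothetical stand-in for the query/key scale that does not reproduce the paper's exact parametrisation of $\beta$.

```python
import numpy as np

def average_cosine_similarity(X):
    """Mean cosine similarity over all distinct token pairs; X has shape (T, d)."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    G = Xn @ Xn.T                      # Gram matrix of unit-norm tokens
    T = X.shape[0]
    return float((G.sum() - np.trace(G)) / (T * (T - 1)))

def softmax_rows(S):
    """Numerically stable row-wise softmax."""
    S = S - S.max(axis=1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=1, keepdims=True)

def mean_row_entropy(A):
    """Average Shannon entropy (in nats) of the rows of an attention matrix."""
    return float(-(A * np.log(A + 1e-30)).sum(axis=1).mean())

def attention_diagnostics(X, score_scale, rng):
    """One random self-attention layer; returns (row entropy, output similarity).

    Assumptions (illustrative only): Gaussian query/key weights with variance
    1/d, identity value weights, and a hypothetical multiplier `score_scale`
    applied to the attention logits.
    """
    T, d = X.shape
    Wq = rng.standard_normal((d, d)) / np.sqrt(d)
    Wk = rng.standard_normal((d, d)) / np.sqrt(d)
    scores = score_scale * (X @ Wq) @ (X @ Wk).T / np.sqrt(d)
    A = softmax_rows(scores)
    Y = A @ X                          # identity value weights (assumption)
    return mean_row_entropy(A), average_cosine_similarity(Y)

rng = np.random.default_rng(0)
T, d = 256, 64
X = rng.standard_normal((T, d))        # nearly orthogonal random tokens

results = {}
for scale in (0.01, 1.0, 100.0):
    results[scale] = attention_diagnostics(X, scale, rng)
    H, rho = results[scale]
    print(f"scale={scale:7.2f}  row entropy={H:.3f} (max {np.log(T):.3f})  "
          f"output cosine similarity={rho:.3f}")
```

In this toy setting a very small scale gives near-uniform attention (entropy close to $\log T$) and output tokens that are almost parallel, the signature of rank collapse, while a very large scale gives near one-hot attention rows with entropy close to zero, the signature of entropy collapse. Where the crossover sits depends on the parametrisation and is not meant to match $\beta_c$.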