Two failure modes of deep transformers and how to avoid them: a unified theory of signal propagation at initialisation
Alessio Giorlandino, Sebastian Goldt
TL;DR
The paper addresses how to initialise deep Transformers to avoid information loss during the forward and backward passes. By mapping self-attention to a Random Energy Model, it derives forward and backward propagation equations and identifies a critical initialisation scale eta_c that separates rank collapse from entropy collapse, yielding trainability diagrams that predict viable hyperparameters. It combines self-attention with skip connections, LayerNorm, and MLPs into a unified, quantitative framework and validates it with case studies on standard BERT-like architectures and LayerNorm placements. The results offer principled initialisation guidelines and motivate gain-controlled attention as a way to avoid both collapse modes, enabling reliable training at scale.
Abstract
Finding the right initialisation for neural networks is crucial to ensure smooth training and good performance. In transformers, the wrong initialisation can lead to one of two failure modes of self-attention layers: rank collapse, where all tokens collapse into similar representations, and entropy collapse, where highly concentrated attention scores lead to training instability. While previous work has studied different scaling regimes for transformers, an asymptotically exact, down-to-the-constant prescription for how to initialise transformers has so far been lacking. Here, we provide an analytical theory of signal propagation through deep transformers with self-attention, layer normalisation, skip connections and MLPs. Our theory yields a simple algorithm to compute trainability diagrams that identify the correct choice of initialisation hyper-parameters for a given architecture. We overcome the key challenge, an exact treatment of the self-attention layer, by establishing a formal parallel with the Random Energy Model from statistical physics. We also analyse gradients in the backward pass and determine the regime where gradients vanish at initialisation. We demonstrate the versatility of our framework through three case studies. Our theoretical framework offers a unified perspective on the two failure modes of self-attention and yields quantitative predictions for the scale of both weights and residual connections that guarantee smooth training.
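The two failure modes described in the abstract can be illustrated with a small numerical sketch. This is not the paper's theory or its trainability-diagram algorithm: it is a single-head attention layer with Gaussian weights, no LayerNorm, skip connections or MLP, and the parameter name `sigma_qk` (the standard deviation scale of the query and key weights) is an assumption introduced here for illustration. It only shows the qualitative picture: a very small query/key scale gives near-uniform attention, so every output token is close to the same average (rank collapse), while a very large scale gives near one-hot attention with low entropy (entropy collapse).

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_stats(sigma_qk, n_tokens=64, d=64, seed=0):
    """Return (mean attention entropy, mean off-diagonal cosine similarity
    of the attention outputs) for one randomly initialised attention head.

    sigma_qk is a hypothetical knob for the query/key weight scale; it is
    not the parameterisation used in the paper.
    """
    rng = np.random.default_rng(seed)

    # Random input tokens, each rescaled to have norm sqrt(d).
    X = rng.standard_normal((n_tokens, d))
    X *= np.sqrt(d) / np.linalg.norm(X, axis=1, keepdims=True)

    # Gaussian query and key weights with entry std sigma_qk / sqrt(d).
    Wq = rng.standard_normal((d, d)) * sigma_qk / np.sqrt(d)
    Wk = rng.standard_normal((d, d)) * sigma_qk / np.sqrt(d)

    # Scaled dot-product attention scores and row-wise softmax.
    scores = (X @ Wq) @ (X @ Wk).T / np.sqrt(d)
    A = softmax(scores, axis=-1)

    # Mean Shannon entropy of the attention rows (maximum is log(n_tokens)).
    entropy = -(A * np.log(A + 1e-30)).sum(axis=1).mean()

    # Attention output without a value projection, to isolate the mixing.
    Y = A @ X
    Yn = Y / np.linalg.norm(Y, axis=1, keepdims=True)
    cos = Yn @ Yn.T
    off_diag = cos[~np.eye(n_tokens, dtype=bool)].mean()

    return entropy, off_diag


if __name__ == "__main__":
    for sigma in (0.01, 1.0, 10.0):
        ent, cos = attention_stats(sigma)
        print(f"sigma_qk={sigma:5.2f}  entropy={ent:.4f}  mean cosine={cos:.4f}")
```

In this toy setting the small-scale run should report an entropy close to log(64) with outputs that are almost perfectly aligned, and the large-scale run an entropy close to zero. Locating the critical scale that separates the two regimes in a full architecture is what the paper's analysis provides; the sketch makes no attempt to compute it.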
