Geometric Dynamics of Signal Propagation Predict Trainability of Transformers

Aditya Cowsik, Tamra Nebabu, Xiao-Liang Qi, Surya Ganguli

TL;DR

The paper studies how deep transformers propagate forward signals and backward gradients at random initialization. It models token geometry with the $n\times n$ dot-product matrix $C$, treats token representations as a system of interacting particles, and derives deterministic layerwise update maps for forward propagation and, separately, for gradient propagation. The analysis introduces two Lyapunov exponents: $\lambda_a$ for the angular dynamics and $\lambda_g$ for gradient growth. The intersection of the forward and backward phase boundaries (the edge of chaos at $\lambda_a=0$ and the boundary between vanishing and exploding gradients at $\lambda_g=0$) yields a simple necessary and sufficient condition for trainability, and the two exponents at initialization predict the final test loss. The results give a principled initialization scheme and a framework for analyzing deep architectures beyond pure attention.
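As a rough reading of what the gradient exponent implies, the following worked arithmetic assumes that the gradient norm changes by a factor of about $e^{\lambda_g t}$ after backpropagating through $t$ layers. This scaling convention (norm rather than squared norm) and the values $\lambda_g = \pm 0.1$ are illustrative assumptions, not taken from the paper.

$$
\frac{\lVert g_{\text{after } t \text{ layers}} \rVert}{\lVert g_{\text{start}} \rVert} \approx e^{\lambda_g t}:
\qquad
\lambda_g = 0.1:\; e^{1.6} \approx 4.95 \;(t=16), \quad e^{6.4} \approx 602 \;(t=64);
\qquad
\lambda_g = -0.1:\; e^{-6.4} \approx 1.7\times 10^{-3} \;(t=64).
$$

Under this assumption, even a small nonzero exponent compounds into a large or tiny factor at depth, which is why the condition $\lambda_g = 0$ matters more as the number of layers grows.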

Abstract

We investigate forward signal propagation and gradient backpropagation in deep, randomly initialized transformers, yielding simple necessary and sufficient conditions on initialization hyperparameters that ensure trainability of deep transformers. Our approach treats the evolution of the representations of $n$ tokens as they propagate through the transformer layers in terms of a discrete time dynamical system of $n$ interacting particles. We derive simple update equations for the evolving geometry of this particle system, starting from a permutation symmetric simplex. Our update equations show that without MLP layers, this system will collapse to a line, consistent with prior work on rank collapse in transformers. However, unlike prior work, our evolution equations can quantitatively track particle geometry in the additional presence of nonlinear MLP layers, and they reveal an order-chaos phase transition as a function of initialization hyperparameters, like the strength of attentional and MLP residual connections and weight variances. In the ordered phase the particles are attractive and collapse to a line, while in the chaotic phase the particles are repulsive and converge to a regular $n$-simplex. We analytically derive two Lyapunov exponents: an angle exponent that governs departures from the edge of chaos in this particle system, and a gradient exponent that governs the rate of exponential growth or decay of backpropagated gradients. We show through experiments that, remarkably, the final test loss at the end of training is well predicted just by these two exponents at the beginning of training, and that the simultaneous vanishing of these two exponents yields a simple necessary and sufficient condition to achieve minimal test loss.
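The collapse without MLP layers can be illustrated with a minimal toy sketch. It is not the paper's model: it assumes uniform attention weights and an identity value map, so each residual attention update adds a multiple of the token mean. The helper names (`gram_summary`, `attention_only_layer`, `cosine_trajectory`) and the small sizes are hypothetical. The sketch summarizes the dot-product matrix $C$ by its mean diagonal entry $q$ and mean off-diagonal entry $p$, and tracks the ratio $p/q$ across layers.

```python
import random


def gram_summary(tokens):
    """Return (q, p): mean diagonal and mean off-diagonal entry of C = X X^T."""
    n = len(tokens)
    dot = lambda a, b: sum(x * y for x, y in zip(a, b))
    q = sum(dot(t, t) for t in tokens) / n
    p = sum(dot(tokens[i], tokens[j])
            for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    return q, p


def attention_only_layer(tokens, alpha):
    """Toy residual update x_i <- x_i + alpha * mean(x).

    Assumption: uniform attention weights and identity values,
    not the paper's randomly initialized attention block.
    """
    n, d = len(tokens), len(tokens[0])
    mean = [sum(t[k] for t in tokens) / n for k in range(d)]
    return [[t[k] + alpha * mean[k] for k in range(d)] for t in tokens]


def cosine_trajectory(n=8, d=16, depth=16, alpha=8 ** -0.5, seed=0):
    """Record p/q at layers 0..depth for Gaussian initial tokens."""
    rng = random.Random(seed)
    tokens = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    cosines = []
    for _ in range(depth + 1):
        q, p = gram_summary(tokens)
        cosines.append(p / q)
        tokens = attention_only_layer(tokens, alpha)
    return cosines


cosines = cosine_trajectory()
print(f"p/q at layer 0: {cosines[0]:.3f}, at layer 16: {cosines[-1]:.3f}")
```

In this toy, the shared mean component grows by a factor of $1+\alpha$ per layer while the deviations between tokens stay fixed, so $p/q$ rises toward 1 and the tokens align along one direction. The paper's equations additionally track what happens when nonlinear MLP layers counteract this collapse.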

Paper Structure (18 sections, 37 equations, 7 figures)


Figures (7)

  • Figure 1: Schematic for the layerwise map of the transformer where $t$ is a layer index. The norm and MLP blocks operate tokenwise (i.e. act on the vector components) while the attention block operates on all of the tokens.
  • Figure 2: We show detailed agreement between our analytic theory (red curves, black arrows) and numerical simulation (blue histograms). The top row shows the token dynamics in the ordered phase where they collapse onto a line. The bottom row shows the token dynamics in the chaotic phase where they self-organize into an $n$-simplex. Our first column shows the vector field $F$ over the space of token norms $q/d$ and cosine angles $p/q$. The red curve traces 16 iterations of $F$ with $\alpha_M = \alpha_A = 8^{-1/2}$ and represents 16 layers of a transformer. We show stable (unstable) fixed points as green (red) octagons. In the second column we plot numerical distributions of token norms (blue), $q/d$, along with the analytic prediction of the expected norm (red). In the third column we similarly compare numerics to analytic predictions for the token angle, $p/q$. Our numerical simulations involve signal propagation in $16$ layer transformers with $n=256$ tokens evolving in embedding dimension $d=64$.
  • Figure 3: Token statistics recover permutation symmetry (right) when it is explicitly violated in the initial conditions. We show that the permutation symmetry assumption (\ref{ass:permutation_symmetry}) is robust by explicitly breaking it with initial conditions of the form \ref{eq:rsb_test}, so that $p_1 = 1.5 p_2 = 0.75d$ initially. We then compute the average $p_1 / p_2$ at several depths in a transformer; the ratio converges to 1 for deep enough models, implying that the assumed symmetric form is restored.
  • Figure 4: We show strong agreement between our analytic calculations (left column) and numerical calculations (right column) for the token angle exponent $\lambda_a$. The first row shows the phase diagram in terms of the strength of the non-residual branch $\alpha = \alpha_M = \alpha_A$ and the standard deviation of MLP weights, $\sigma_w$. The color depicts the token angle Lyapunov exponent $\lambda_a$, with a red positive (blue negative) value corresponding to the chaotic (collapsed) regime. As $\alpha$ increases, collapse due to attention strengthens, and so $\sigma_w$ must increase as well so that chaos due to the MLP can counteract attentional collapse and maintain dynamics at the edge of chaos with $\lambda_a = 0$. The second row shows $\lambda_a$ as a function of $\alpha_A$ and $\alpha_M$ with a fixed $\sigma_w=2$. Similarly, as $\alpha_A$ increases, so must $\alpha_M$ to maintain $\lambda_a=0$, in order to balance stronger attentional collapse with stronger MLP chaotic amplification.
  • Figure 5: The gradient exponent $\lambda_g$ in a deep transformer as a function of $\sigma_w$ and $\alpha = \alpha_A = \alpha_M$ (first row), and as a function of $\alpha_M$ and $\alpha_A$ at fixed $\sigma_w = 2$ (second row). The left column shows the analytic calculation and the right column the numerical one, demonstrating substantial agreement. Compared to the angle exponent $\lambda_a$, demanding $\lambda_g = 0$ requires a smaller $\sigma_w$, but a similar ratio of $\alpha_A$ to $\alpha_M$.
  • ...and 2 more figures
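The sign convention used in the phase diagrams (negative exponent for the ordered regime, positive for the chaotic one) can be illustrated with a generic numerical estimate. The paper derives its exponents analytically; the sketch below instead fits the slope of the log-deviation from a fixed point against the layer index, on a hypothetical linear toy map. The names `estimate_exponent` and `make_toy_map`, and the slopes 0.5 and 1.5, are assumptions for illustration only.

```python
import math


def estimate_exponent(step, x_star, x0, depth):
    """Least-squares slope of log|x_t - x_star| against layer index t."""
    logs = []
    x = x0
    for _ in range(depth + 1):
        logs.append(math.log(abs(x - x_star)))
        x = step(x)
    ts = range(depth + 1)
    t_mean = sum(ts) / len(logs)
    l_mean = sum(logs) / len(logs)
    num = sum((t - t_mean) * (l - l_mean) for t, l in zip(ts, logs))
    den = sum((t - t_mean) ** 2 for t in ts)
    return num / den


def make_toy_map(slope, x_star=1.0):
    """Hypothetical layer map, linear around its fixed point x_star."""
    return lambda x: x_star + slope * (x - x_star)


# Start slightly off the fixed point and iterate 16 "layers".
lam_ordered = estimate_exponent(make_toy_map(0.5), 1.0, 1.0 + 1e-3, 16)
lam_chaotic = estimate_exponent(make_toy_map(1.5), 1.0, 1.0 + 1e-3, 16)
print(f"ordered: {lam_ordered:.4f}, chaotic: {lam_chaotic:.4f}")
```

For a linear map the estimate equals the log of the slope, so a contracting map gives a negative exponent (perturbations decay with depth) and an expanding map gives a positive one (perturbations grow). An exponent of zero marks the boundary between the two regimes.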