From Condensation to Rank Collapse: A Two-Stage Analysis of Transformer Training Dynamics

Zheng-An Chen, Tao Luo

TL;DR

This work analyzes transformer training dynamics under small initialization through gradient flow, revealing a two-stage process: first, asymmetric weight perturbations induce condensation of outer parameters toward a target direction; second, after these parameters stabilize, the key-query matrices activate and undergo rank-collapse, refining representations. The framework combines rigorous blow-up and condensation results with a dynamics-separation regime that yields linear key-query dynamics governed by a matrix $\mathbf{F}$, enabling rank-1 collapse when $\mathbf{F}$ has a unique largest singular value. Experimental validation on synthetic data and WikiText demonstrates the two-stage trajectory, including observable condensation in outer weights and subsequent rank collapse in the attention components. These findings offer a principled mechanism for implicit regularization in overparameterized transformers and a foundation for understanding why small initialization can promote robust generalization and efficient learning in language models.
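The link between a unique largest singular value of $\mathbf{F}$ and rank-1 collapse can be illustrated numerically. The sketch below is a toy, not the paper's effective dynamics: it assumes a coupled linear system $\dot{\mathbf{Q}} = \mathbf{F}\mathbf{K}$, $\dot{\mathbf{K}} = \mathbf{F}^\top \mathbf{Q}$ chosen purely for illustration, and the names `simulate_toy_key_query` and `singular_value_ratio` are hypothetical. Under that assumption, the component along the top singular direction of $\mathbf{F}$ grows fastest, so the normalized matrix approaches rank 1.

```python
import numpy as np

def simulate_toy_key_query(F, Q0, K0, dt=1e-2, steps=1000):
    """Explicit-Euler integration of the ASSUMED toy system
    dQ/dt = F K,  dK/dt = F^T Q  (illustrative, not the paper's equations)."""
    Q, K = Q0.copy(), K0.copy()
    for _ in range(steps):
        # Both right-hand sides use the values from the previous step.
        Q, K = Q + dt * (F @ K), K + dt * (F.T @ Q)
    return Q, K

def singular_value_ratio(W):
    """Second-largest over largest singular value; near 0 means near rank 1."""
    s = np.linalg.svd(W, compute_uv=False)
    return s[1] / s[0]

rng = np.random.default_rng(0)
d = 6

# Build F with a unique largest singular value (3.0 versus at most 1.0).
U, _ = np.linalg.qr(rng.standard_normal((d, d)))
V, _ = np.linalg.qr(rng.standard_normal((d, d)))
sigma = np.array([3.0, 1.0, 0.8, 0.6, 0.4, 0.2])
F = U @ np.diag(sigma) @ V.T

# Small random initialization, mimicking the small-initialization regime.
Q0 = 1e-3 * rng.standard_normal((d, d))
K0 = 1e-3 * rng.standard_normal((d, d))

Q, K = simulate_toy_key_query(F, Q0, K0)

print("initial s2/s1:", singular_value_ratio(Q0))
print("final   s2/s1:", singular_value_ratio(Q))

# The dominant left singular vector of Q should align with that of F.
u_top = np.linalg.svd(Q)[0][:, 0]
print("alignment with top singular vector of F:", abs(u_top @ U[:, 0]))
```

If the two largest entries of `sigma` are made equal, no single direction dominates in this toy system and the ratio no longer decays toward zero, which is consistent with the uniqueness condition stated above.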

Abstract

Although transformer-based models have shown exceptional empirical performance, the fundamental principles governing their training dynamics are inadequately characterized beyond configuration-specific studies. Inspired by empirical evidence showing improved reasoning capabilities under small initialization scales in language models, we employ the gradient flow analytical framework established in [Zhou et al. NeurIPS 2022] to systematically investigate linearized Transformer training dynamics. Our theoretical analysis dissects the dynamics of attention modules into two distinct stages. In the first stage, asymmetric weight perturbations from random initialization sustain non-degenerate gradient dynamics in parameter matrices, facilitating systematic escape from small initialization regimes. Subsequently, these matrices undergo condensation, progressively aligning toward the target orientation. In the second stage, the previously static key-query matrices actively participate in training, driving the normalized matrices toward asymptotic rank collapse. This two-stage framework generalizes classical directional convergence results.

Paper Structure

This paper contains 33 sections, 12 theorems, 109 equations, 5 figures, 2 tables.

Key Result

Proposition 1

Given a binary dataset $\{({\bm{X}}_i, y_i)\}_{i=1}^n$, define the condensation direction $\bm{v}$ and the rescaled time coordinate $\bar{t}$. Then the normalized parameters $\bar{\bm{\theta}}$ follow the leading-order dynamics after rescaling.

Figures (5)

  • Figure 1: (a) Evolution of cosine similarity matrices for outer and attention parameters. The training process is partitioned into three stages: Condensation (Stage 1), Key-Query rank collapse (Stage 2), and a further training stage. Stage transitions are identified by plateaus in the loss curve and structural shifts in these matrices. (b) The relative change of norms between attention and outer parameters. The gray dashed line marks the onset of Stage 2, where updates to the attention parameters begin to dominate. (c) Evolution of the effective rank for both parameter groups, tracking the change in their intrinsic dimensionality throughout training.
  • Figure 2: (a) Proportion of satisfied conditions in Assumption \ref{assump::FinalStageCond}, measured as $\frac{|A_1|}{d_m}$ and $\frac{|A_2|}{d_m^2}$ (for the definitions of $A_1$ and $A_2$, see Sec. \ref{subsec::Experimental_Setting}). (b) Similarity between singular vectors at two adjacent time steps. For example, let ${\bm{U}}_t \boldsymbol{\Sigma}_t {\bm{V}}_t$ and ${\bm{U}}_{t+1} \boldsymbol{\Sigma}_{t+1} {\bm{V}}_{t+1}$ be the singular value decompositions of the parameter matrices ${\bm{W}}_t$ and ${\bm{W}}_{t+1}$. The similarity is defined as $\frac{1}{d_m} \sum_{i=1}^{d_m} \cos({\bm{u}}_t^i, {\bm{u}}_{t+1}^i)$ (or $\frac{1}{d_m} \sum_{i=1}^{d_m} \cos(\boldsymbol{v}_t^i, \boldsymbol{v}_{t+1}^i)$). (c) Frobenius norms of parameter groups.
  • Figure 3: Evolution of cosine similarity between the parameters of the two-layer transformer on the WikiText dataset. The training dynamics show a similar three-phase pattern. Superscripts indicate parameters of different layers, and subscripts indicate different parameters within a layer. For example, ${\bm{W}}^{1}_V$ represents the value matrix of the first layer.
  • Figure 4: (a) Evolution of cosine similarity matrices for outer and attention parameters. The training process is partitioned into three stages: Condensation (Stage 1), Key-Query rank collapse (Stage 2), and a further training stage. Stage transitions are identified by plateaus in the loss curve and structural shifts in these matrices. (b) The relative change of norms between attention and outer parameters. The gray dashed line marks the onset of Stage 2, where updates to the attention parameters begin to dominate. (c) Evolution of the effective rank for both parameter groups, tracking the change in their intrinsic dimensionality throughout training.
  • Figure 5: (a) Proportion of satisfied conditions in Assumption \ref{assump::FinalStageCond}, measured as $\frac{|A_1|}{d_m}$ and $\frac{|A_2|}{d_m^2}$. (b) Similarity between singular vectors of two adjacent time steps. (c) Frobenius norms of parameter groups.
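The diagnostics named in the captions above can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the authors' code. The singular-vector similarity follows the formula in the Figure 2 caption. The captions do not define the effective rank, so the exponential of the entropy of the normalized singular values is assumed here. The cosine similarity matrix is assumed to be taken between rows of a weight matrix. All function names are hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(W, eps=1e-12):
    """Pairwise cosine similarity between the rows of W (row-wise is an assumption)."""
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    W_unit = W / np.maximum(norms, eps)
    return W_unit @ W_unit.T

def effective_rank(W, eps=1e-12):
    """ASSUMED definition: exp of the Shannon entropy of normalized singular values."""
    s = np.linalg.svd(W, compute_uv=False)
    p = s / max(s.sum(), eps)
    p = p[p > eps]  # drop numerically zero entries (0 * log 0 is treated as 0)
    return float(np.exp(-(p * np.log(p)).sum()))

def singular_vector_similarity(W_t, W_next):
    """Mean cosine between matching singular vectors of two matrices,
    (1/d) * sum_i cos(u_t^i, u_{t+1}^i), and likewise for right vectors.
    Note: SVD signs are arbitrary, so cosines can flip sign between steps."""
    U_t, _, Vh_t = np.linalg.svd(W_t, full_matrices=False)
    U_n, _, Vh_n = np.linalg.svd(W_next, full_matrices=False)
    # Singular vectors have unit norm, so the cosine is the inner product.
    left = float(np.mean(np.sum(U_t * U_n, axis=0)))     # columns of U
    right = float(np.mean(np.sum(Vh_t * Vh_n, axis=1)))  # rows of Vh
    return left, right

rng = np.random.default_rng(0)
d = 8

# A fully condensed matrix: every row is a multiple of one direction v.
v = rng.standard_normal(d)
a = rng.standard_normal(d)
condensed = np.outer(a, v)

cos_mat = cosine_similarity_matrix(condensed)
print("all |cos| close to 1:", np.allclose(np.abs(cos_mat), 1.0))
print("effective rank (condensed):", effective_rank(condensed))
print("effective rank (identity): ", effective_rank(np.eye(d)))

W = rng.standard_normal((d, d))
print("similarity of W with itself:", singular_vector_similarity(W, W))
```

In practice these functions would be applied to parameter snapshots saved at successive training steps; a sign-invariant variant (taking absolute values of the cosines) may be preferable when tracking singular vectors over time.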

Theorems & Definitions (29)

  • Definition 1: Condensation
  • Definition 2: Asymptotic rank collapse
  • Definition 3: One-layer transformer
  • Proposition 1: Effective training dynamics
  • Proposition 2: Conservation laws
  • Definition 4: Non-degenerate initialization
  • Theorem 1: Blow-up in finite time
  • Theorem 2: Condensation
  • Proposition 3: Effective dynamics during dynamics separation stage
  • Theorem 3: Asymptotic rank collapse
  • ...and 19 more