Transformers learn through gradual rank increase

Enric Boix-Adsera, Etai Littwin, Emmanuel Abbe, Samy Bengio, Joshua Susskind

TL;DR

This work investigates transformer training dynamics under a diagonal-weight simplification. It proves that gradient flow exhibits stagewise, incremental learning: the learned perturbation from initialization increases in rank by at most one per stage, and each stage lasts $\Theta(\log(1/\alpha))$ time, where $\alpha$ is the initialization scale. The authors prove a general theorem for networks of the form $f_{\mathsf{NN}}(x;u,v)=h(x;u\odot v)$ and instantiate it for attention heads with diagonal weights, alongside intuition based on a conservation law that reduces the dynamics to tracking $\boldsymbol w(t)=\boldsymbol u(t)+\boldsymbol v(t)$. They corroborate the theory with experiments on vision and language transformers, observing gradual rank growth of the weight perturbations even beyond the simplifying assumptions, and draw practical parallels to low-rank fine-tuning methods such as LoRA. The work points to a potential mechanism behind efficient adaptation via low-rank perturbations and opens avenues for exploiting this behavior to improve the training and fine-tuning of large transformers.
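The conservation law mentioned above can be checked numerically. For any loss that depends on $\boldsymbol u,\boldsymbol v$ only through $\boldsymbol\theta=\boldsymbol u\odot\boldsymbol v$, the chain rule gives $\frac{d}{dt}u_i^2=\frac{d}{dt}v_i^2=-2u_iv_i\,\partial L/\partial\theta_i$, so $u_i^2-v_i^2$ stays constant under gradient flow; we assume this is the quantity the conservation law refers to. The sketch below is a minimal illustration under our own assumptions: the toy model $\tanh(\sum_i a_i u_i v_i)$, the constants `A` and `Y`, and the helper names are hypothetical, and small-step gradient descent stands in for gradient flow.

```python
import math

# Hypothetical toy model (our assumption, not from the paper): a coupled, nonlinear
# network that depends on (u, v) only through theta = u * v (elementwise):
#   f(u, v) = tanh(sum_i A[i] * u_i * v_i), trained on one example with label Y.
A = [1.0, -2.0, 0.5]
Y = 0.8


def loss_and_grad_theta(theta):
    """Return L(theta) = 0.5 * (tanh(A . theta) - Y)^2 and dL/dtheta."""
    s = sum(a * t for a, t in zip(A, theta))
    r = math.tanh(s) - Y
    dl_ds = r * (1.0 - math.tanh(s) ** 2)
    return 0.5 * r * r, [dl_ds * a for a in A]


def train(u, v, lr=1e-3, steps=5000):
    """Small-step gradient descent on (u, v), used as a proxy for gradient flow."""
    u, v = list(u), list(v)
    for _ in range(steps):
        theta = [ui * vi for ui, vi in zip(u, v)]
        _, g = loss_and_grad_theta(theta)
        # Chain rule through theta_i = u_i * v_i: dL/du_i = v_i g_i, dL/dv_i = u_i g_i.
        # Both right-hand sides use the old (u, v), i.e. a simultaneous update.
        u, v = ([ui - lr * vi * gi for ui, vi, gi in zip(u, v, g)],
                [vi - lr * ui * gi for ui, vi, gi in zip(u, v, g)])
    return u, v


u0, v0 = [0.3, 0.1, 0.2], [0.1, 0.2, 0.05]
u1, v1 = train(u0, v0)

loss0, _ = loss_and_grad_theta([a * b for a, b in zip(u0, v0)])
loss1, _ = loss_and_grad_theta([a * b for a, b in zip(u1, v1)])
c0 = [a * a - b * b for a, b in zip(u0, v0)]  # u_i^2 - v_i^2 at initialization
c1 = [a * a - b * b for a, b in zip(u1, v1)]  # u_i^2 - v_i^2 after training
rel_drift = max(abs(x - y) / abs(x) for x, y in zip(c0, c1))

print("loss before/after:", loss0, loss1)
print("u_i^2 - v_i^2 before:", c0)
print("u_i^2 - v_i^2 after: ", c1)
print("max relative drift:", rel_drift)
```

A discrete step multiplies each $u_i^2-v_i^2$ by exactly $1-\eta^2 g_i^2$ (with $\eta$ the step size and $g_i=\partial L/\partial\theta_i$), so the quantity is conserved only up to $O(\eta^2)$ per step and exactly in the gradient-flow limit, while the loss itself changes substantially.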

Abstract

We identify incremental learning dynamics in transformers, where the difference between trained and initial weights progressively increases in rank. We rigorously prove this occurs under the simplifying assumptions of diagonal weight matrices and small initialization. Our experiments support the theory and also show that the phenomenon can occur in practice without the simplifying assumptions.

Paper Structure

This paper contains 53 sections, 17 theorems, 89 equations, 19 figures, and 1 algorithm.

Key Result

Theorem 1.1

Let $f_{\mathsf{NN}}$ be a network of the form $f_{\mathsf{NN}}(\boldsymbol x;{\boldsymbol u},{\boldsymbol v})=h(\boldsymbol x;{\boldsymbol u}\odot{\boldsymbol v})$, and suppose that the weights are initialized very small, i.e., the entries of ${\boldsymbol u},{\boldsymbol v}$ are of order $\Theta(\alpha)$ at initialization for some small $\alpha > 0$. Then the dynamics of gradient flow training e
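To make the role of $\log(1/\alpha)$ concrete, here is a minimal sketch on the simplest product-structured toy we could think of, which is our own choice and not the paper's attention-head experiment: fit a target $\boldsymbol w^*$ with $\boldsymbol\theta=\boldsymbol u\odot\boldsymbol v$ under $L=\tfrac12\|\boldsymbol\theta-\boldsymbol w^*\|^2$, starting from $u_i=v_i=\alpha$. With $u=v$, each coordinate obeys the logistic ODE $\dot\theta_i=2\theta_i(w^*_i-\theta_i)$, whose half-time is $\log(w^*_i/\alpha^2-1)/(2w^*_i)\approx\log(1/\alpha)/w^*_i$. In time rescaled by $\log(1/\alpha)$, coordinate $i$ should therefore switch on near $1/w^*_i$, one coordinate per stage. The target `W_STAR` and the helper `half_times` are hypothetical.

```python
import math

# Hypothetical toy target (our choice): fit W_STAR with theta = u * v under
# L(theta) = 0.5 * sum_i (theta_i - W_STAR[i])^2, starting from u_i = v_i = alpha.
W_STAR = [3.0, 2.0, 1.0]


def half_times(alpha, lr=1e-3, steps=20000):
    """Run small-step gradient descent (a proxy for gradient flow) and return, per
    coordinate, the first time t = step * lr at which theta_i >= W_STAR[i] / 2."""
    n = len(W_STAR)
    u, v = [alpha] * n, [alpha] * n
    t_half = [None] * n
    for step in range(1, steps + 1):
        g = [ui * vi - w for ui, vi, w in zip(u, v, W_STAR)]  # dL/dtheta_i
        u, v = ([ui - lr * vi * gi for ui, vi, gi in zip(u, v, g)],
                [vi - lr * ui * gi for ui, vi, gi in zip(u, v, g)])
        for i in range(n):
            if t_half[i] is None and u[i] * v[i] >= W_STAR[i] / 2:
                t_half[i] = step * lr
    return t_half


# Rescaled half-times t / log(1/alpha) should approach 1 / W_STAR[i] as alpha -> 0.
for alpha in (1e-2, 1e-4, 1e-8):
    rescaled = [t / math.log(1 / alpha) for t in half_times(alpha)]
    print(f"alpha={alpha:g}: rescaled half-times {[round(r, 3) for r in rescaled]}")
print("limit 1/w*:", [round(1 / w, 3) for w in W_STAR])
```

Reading $\boldsymbol\theta$ as the diagonal of a matrix, its rank (the number of learned coordinates) grows by one at each stage in this toy, and the gaps between stages in unrescaled time grow like $\log(1/\alpha)$.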

Figures (19)

  • Figure 1: (a) Loss versus rescaled time in the toy task of learning an attention head with diagonal weights, for various initialization scales $\alpha$. The loss curves converge as $\alpha \to 0$ to a curve with stagewise loss plateaus and sharp decreases, as predicted by the theory; some stagewise learning behavior is already clear with $\alpha = 0.01$. (b) Each line shows the evolution of one of the entries of $\mathrm{diag}({\boldsymbol w}_Q)\mathrm{diag}({\boldsymbol w}_K)$ and $\mathrm{diag}({\boldsymbol w}_V)\mathrm{diag}({\boldsymbol w}_O)$ over rescaled time, demonstrating that the rank of these matrices increases incrementally; see the toy-model appendix for experimental details and further results.
  • Figure 2: Validation of assumptions on the toy model of learning a single attention head. (a) Strict-local-minimum assumption: weights perturbed at a random time during training (solid lines) tend back to the near-stationary point (dashed lines). (b) Robust-dynamics assumption: weights perturbed at the beginning of a stage (solid lines) have the same nonlinear evolution as without perturbation (dashed lines). Details of these experiments and further validations are provided in the toy-model appendix.
  • Figure 3: Stable rank of $\Delta{\boldsymbol W}_K {\boldsymbol W}_Q^\top$ (blue) and $\Delta{\boldsymbol W}_V {\boldsymbol W}_O^\top$ (orange) on an arbitrarily chosen layer throughout training, for four different pairs of networks and tasks. The stable rank of a matrix ${\boldsymbol W}$ is defined as $\|{\boldsymbol W}\|^2_F / \|{\boldsymbol W}\|_2^2$ and gives a smooth approximation of the rank. Mean and standard deviation (shaded area) are computed across all heads in each attention layer. Full details and results are in Appendix D.1.
  • Figure 4: Spectrum of the weight perturbation $\Delta{\boldsymbol W}_K {\boldsymbol W}_Q^\top$ vs. initialization in a vision transformer trained on CIFAR-10 using Adam and the default initialization scale, in random self-attention heads in different layers. The learned perturbation exhibits an extreme low-rank bias after training, even at the default initialization scale. Analogous plots for CIFAR-100 and ImageNet are in Appendix D.1.
  • Figure 5: Training a vision transformer on CIFAR-10 using Adam while varying the initialization scale (unit scale indicates default initialization). Plotted is the evolution of the eigenvalues of $\Delta{\boldsymbol W}_K {\boldsymbol W}_Q^\top$ (a)–(c) and $\Delta{\boldsymbol W}_V {\boldsymbol W}_O^\top$ (d)–(f) in a random self-attention head in the second layer throughout training. Incremental learning dynamics and a low-rank bias are evident at all scales, albeit more pronounced at smaller initialization scales.
  • ...and 14 more figures
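The figures above measure the rank of the learned perturbation through its stable rank $\|{\boldsymbol W}\|_F^2/\|{\boldsymbol W}\|_2^2$ and its spectrum. The sketch below computes both with NumPy for a synthetic perturbation. Everything about the trajectory is hypothetical: `delta_at`, the three orthonormal directions, and the switch-on times 0.25, 0.5, 0.75 only mimic stagewise learning and are not data from the paper. We also assume that $\Delta{\boldsymbol W}_K{\boldsymbol W}_Q^\top$ denotes the change of the product ${\boldsymbol W}_K{\boldsymbol W}_Q^\top$ from its initial value; with real checkpoints, `delta` would be that difference.

```python
import numpy as np


def stable_rank(w):
    """||W||_F^2 / ||W||_2^2. Equals the rank when all nonzero singular values are
    equal and is otherwise a continuous lower bound on the rank."""
    spec = np.linalg.norm(w, 2)  # largest singular value
    if spec == 0:
        return 0.0
    return float(np.linalg.norm(w, "fro") ** 2 / spec**2)


rng = np.random.default_rng(0)
d = 32
# Hypothetical perturbation: three orthonormal rank-one directions that "switch on"
# at staggered times, so its singular values are exactly the three gains below.
q_left, _ = np.linalg.qr(rng.normal(size=(d, 3)))
q_right, _ = np.linalg.qr(rng.normal(size=(d, 3)))


def delta_at(t):
    gains = [1.0 / (1.0 + np.exp(-10.0 * (t - c))) for c in (0.25, 0.5, 0.75)]
    return sum(g * np.outer(q_left[:, k], q_right[:, k]) for k, g in enumerate(gains))


for t in (0.0, 0.3, 0.6, 1.0):
    delta = delta_at(t)
    sv = np.linalg.svd(delta, compute_uv=False)  # singular values, descending
    print(f"t={t:.1f}  stable rank={stable_rank(delta):.2f}  top singular values={np.round(sv[:4], 3)}")
```

The stable rank is scale-invariant, so a perturbation dominated by one direction reads as roughly rank one even when many tiny singular values are present; this is what makes it a convenient smooth proxy for tracking rank over training.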

Theorems & Definitions (33)

  • Theorem 1.1: Informal statement of incremental learning dynamics
  • Definition 3.1
  • Example 3.2: Transformer with diagonal weights
  • Theorem 4.1: Incremental dynamics at small initialization
  • Lemma 4.2: Conservation law
  • Proof
  • Theorem C.4: Restatement of Theorem 4.1
  • Lemma C.5: Stability of active variables during part (A) of dynamics
  • Proof
  • Lemma C.6: Log-scale approximation is correct during part (A)
  • ...and 23 more