Transformers learn through gradual rank increase
Enric Boix-Adsera, Etai Littwin, Emmanuel Abbe, Samy Bengio, Joshua Susskind
TL;DR
This work investigates transformer training dynamics under a diagonal-weight simplification, proving that gradient flow exhibits stagewise, incremental learning: the learned perturbation from initialization increases in rank by at most one per stage, and each stage lasts $\Theta(\log(1/\alpha))$, where $\alpha$ is the (small) initialization scale. The authors provide a general theorem for networks of the form $f_{NN}(x;u,v)=h(x;u\odot v)$ and instantiate it for diagonal attention heads, alongside intuition based on a conservation law that reduces the dynamics to tracking $\boldsymbol w(t)=\boldsymbol u(t)+\boldsymbol v(t)$. They corroborate the theory with experiments on vision and language transformers, observing gradual rank growth of weight perturbations even beyond the simplifying assumptions, and note practical parallels to low-rank fine-tuning methods like LoRA. The results point to a potential mechanism for efficient adaptation via low-rank perturbations and open avenues to leverage this behavior for improved training and fine-tuning of large transformers.
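To make the conservation-law intuition concrete, here is a short sketch derived only from the form $f_{NN}(x;u,v)=h(x;u\odot v)$. The symbols $\mathcal{L}$ (the training loss), $\boldsymbol\theta=\boldsymbol u\odot\boldsymbol v$, and $\boldsymbol g=\nabla_{\boldsymbol\theta}\mathcal{L}$ are notation assumed here for illustration and may differ from the paper's. Under gradient flow, the chain rule gives

$$
\begin{aligned}
\frac{d\boldsymbol u}{dt} &= -\boldsymbol v\odot\boldsymbol g, \qquad \frac{d\boldsymbol v}{dt} = -\boldsymbol u\odot\boldsymbol g,\\
\frac{d}{dt}\left(\boldsymbol u\odot\boldsymbol u-\boldsymbol v\odot\boldsymbol v\right) &= -2\,\boldsymbol u\odot\boldsymbol v\odot\boldsymbol g+2\,\boldsymbol v\odot\boldsymbol u\odot\boldsymbol g=\boldsymbol 0,\\
\frac{d\boldsymbol w}{dt} &= -(\boldsymbol u+\boldsymbol v)\odot\boldsymbol g=-\boldsymbol w\odot\boldsymbol g .
\end{aligned}
$$

So $\boldsymbol u\odot\boldsymbol u-\boldsymbol v\odot\boldsymbol v$ stays at its initial value, and because $(\boldsymbol u+\boldsymbol v)\odot(\boldsymbol u-\boldsymbol v)$ equals this constant, $\boldsymbol u-\boldsymbol v$ (and hence $\boldsymbol\theta$) can be recovered from $\boldsymbol w$ alone wherever $\boldsymbol w$ is nonzero. This is why it suffices to track $\boldsymbol w(t)$.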
Abstract
We identify incremental learning dynamics in transformers, where the difference between trained and initial weights progressively increases in rank. We rigorously prove this occurs under the simplifying assumptions of diagonal weight matrices and small initialization. Our experiments support the theory and also show that the phenomenon can occur in practice without the simplifying assumptions.
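One way to probe this rank growth empirically is to count the significant singular values of the difference between a trained weight matrix and its initialization. The following is a minimal sketch under assumptions, not the authors' measurement procedure: the helper name `perturbation_rank` and the relative threshold `rel_tol` are hypothetical choices made for illustration.

```python
import numpy as np


def perturbation_rank(w_init, w_trained, rel_tol=1e-2):
    """Count singular values of (w_trained - w_init) above rel_tol * the largest.

    Hypothetical helper: the relative threshold is an illustrative choice,
    not a definition taken from the paper.
    """
    delta = np.asarray(w_trained, dtype=float) - np.asarray(w_init, dtype=float)
    # Singular values are returned in descending order.
    singular_values = np.linalg.svd(delta, compute_uv=False)
    if singular_values.size == 0 or singular_values[0] == 0.0:
        return 0  # no perturbation at all
    return int(np.sum(singular_values > rel_tol * singular_values[0]))


# Toy check: a random "initialization" plus a perturbation with two
# nonzero singular values (3 and 1).
rng = np.random.default_rng(0)
w0 = rng.normal(size=(8, 8))
delta = np.zeros((8, 8))
delta[0, 0] = 3.0
delta[1, 1] = 1.0
rank_example = perturbation_rank(w0, w0 + delta)
```

Applied to checkpoints saved during training, such a count would be expected to grow gradually if the incremental dynamics described above hold. The result depends on the chosen threshold, so it is best read as a qualitative indicator.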
