Linear attention is (maybe) all you need (to understand transformer optimization)
Kwangjun Ahn, Xiang Cheng, Minhak Song, Chulhee Yun, Ali Jadbabaie, Suvrit Sra
TL;DR
The paper investigates why Transformer training is difficult and proposes a minimal linear attention model as an abstraction that exhibits known optimization phenomena. By training a shallow linear Transformer on random linear regression tasks, the authors show that the model reproduces heavy-tailed gradient noise, ill-conditioned loss landscapes, directional and generalized smoothness, and the advantage of adaptive optimizers over SGD. They further analyze the effects of the data distribution and of network depth, showing that these features become more pronounced with heavier-tailed data and with more layers. The work provides a tractable framework for developing theory and designing optimization algorithms for Transformers.
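To make the setup concrete, here is a minimal pure-Python sketch of a linear-attention Transformer applied to in-context linear regression. It assumes the residual linear self-attention parameterization Z <- Z + (1/n) P Z M Z^T Q Z from the cited line of work (K. Ahn et al., NeurIPS 2023), with the prediction read from the bottom-right entry of the final Z with a sign flip. The exact parameterization and sign convention, the helper names (sample_prompt, lsa_layer, gd_params, one_step_gd_prediction), and the Student-t covariates used for the heavy-tailed option are illustrative assumptions, not the paper's code. As a sanity check, a single layer with one particular choice of (P, Q) reproduces one preconditioned gradient-descent step from w = 0.

```python
import math
import random


def sample_prompt(n, d, heavy_tailed=False, rng=random):
    """Sample one in-context regression prompt (X, y, x_query, w_star).

    y_i = <w_star, x_i> with w_star ~ N(0, I). Covariates are N(0, I), or, if
    heavy_tailed=True, multivariate Student-t with nu = 3 (an assumed,
    illustrative heavy-tailed choice, not necessarily the paper's).
    """
    def covariate():
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        if heavy_tailed:
            nu = 3
            chi2 = sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(nu))
            x = [v / math.sqrt(chi2 / nu) for v in x]  # Gaussian / sqrt(chi2/nu)
        return x

    w_star = [rng.gauss(0.0, 1.0) for _ in range(d)]
    X = [covariate() for _ in range(n)]
    y = [sum(w * v for w, v in zip(w_star, x)) for x in X]
    return X, y, covariate(), w_star


def transpose(A):
    return [list(col) for col in zip(*A)]


def matmul(A, B):
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]


def build_Z(X, y, x_query):
    """Stack the prompt as a (d+1) x (n+1) matrix: columns (x_i, y_i), then (x_query, 0)."""
    cols = [list(x) + [yi] for x, yi in zip(X, y)] + [list(x_query) + [0.0]]
    return transpose(cols)


def lsa_layer(Z, P, Q):
    """Residual linear self-attention (assumed form): Z + (1/n) * P Z M Z^T Q Z.

    M = diag(1, ..., 1, 0) zeroes the query column, so only the n labelled
    examples contribute to Z M Z^T.
    """
    n = len(Z[0]) - 1
    ZM = [row[:n] + [0.0] for row in Z]
    update = matmul(matmul(matmul(matmul(P, ZM), transpose(Z)), Q), Z)
    return [[z + u / n for z, u in zip(zr, ur)] for zr, ur in zip(Z, update)]


def linear_transformer(Z, layers):
    """Apply a stack of (P, Q) layers (depth = len(layers)); read off -Z[d+1, n+1]."""
    for P, Q in layers:
        Z = lsa_layer(Z, P, Q)
    return -Z[-1][-1]


def gd_params(d, A):
    """Hypothetical (P, Q) making one layer equal one GD step with symmetric preconditioner A."""
    P = [[0.0] * (d + 1) for _ in range(d + 1)]
    P[d][d] = 1.0  # only the label row is updated
    Q = [[0.0] * (d + 1) for _ in range(d + 1)]
    for i in range(d):
        for j in range(d):
            Q[i][j] = -A[i][j]
    return P, Q


def one_step_gd_prediction(X, y, x_query, A):
    """<w_1, x_query> with w_1 = -A grad L(0), L(w) = (1/2n) sum_i (<w, x_i> - y_i)^2."""
    n, d = len(X), len(x_query)
    grad0 = [-sum(X[i][k] * y[i] for i in range(n)) / n for k in range(d)]
    w1 = [-sum(A[k][j] * grad0[j] for j in range(d)) for k in range(d)]
    return sum(w * v for w, v in zip(w1, x_query))


if __name__ == "__main__":
    rng = random.Random(0)
    d, n = 3, 20
    X, y, x_query, w_star = sample_prompt(n, d, heavy_tailed=True, rng=rng)
    A = [[0.5 if i == j else 0.0 for j in range(d)] for i in range(d)]
    Z0 = build_Z(X, y, x_query)
    # The two printed predictions should agree up to floating-point error.
    print(linear_transformer(Z0, [gd_params(d, A)]))
    print(one_step_gd_prediction(X, y, x_query, A))
```

Passing several (P, Q) pairs to linear_transformer stacks layers, which is where depth enters; training P and Q (for example with SGD or an adaptive method on the squared error against <w_star, x_query>) is omitted to keep the sketch short.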
Abstract
Transformer training is notoriously difficult, requiring careful design of optimizers and the use of various heuristics. We make progress towards understanding the subtleties of training Transformers by carefully studying a simple yet canonical linearized shallow Transformer model. Specifically, we train linear Transformers to solve regression tasks, inspired by J. von Oswald et al. (ICML 2023) and K. Ahn et al. (NeurIPS 2023). Most importantly, we observe that our proposed linearized models can reproduce several prominent aspects of Transformer training dynamics. Consequently, the results obtained in this paper suggest that a simple linearized Transformer model could actually be a valuable, realistic abstraction for understanding Transformer optimization.
