JoMA: Demystifying Multilayer Transformers via JOint Dynamics of MLP and Attention
Yuandong Tian, Yiping Wang, Zhenyu Zhang, Beidi Chen, Simon Du
TL;DR
JoMA is a unified framework that analyzes the training of multilayer Transformers by treating self-attention and MLP dynamics jointly, removing assumptions made in prior analyses such as the absence of residual connections and purely linear activations. It shows that attention logits follow invariants during training: under nonlinear activations, attention first becomes sparse, focusing on salient tokens, and later becomes denser as less salient tokens are learned, while under linear activation it recovers the sparse attention reported in earlier work. Assuming tokens are generated by a hierarchical latent-tree model, JoMA qualitatively explains how hierarchical token structure emerges across layers. Experiments on models trained on Wikitext2/Wikitext103 and on pre-trained models (OPT, Pythia) support the theory. The work clarifies how self-attention interacts with nonlinear MLPs to produce hierarchical representations, which can inform the analysis and design of Transformer training.
Abstract
We propose Joint MLP/Attention (JoMA) dynamics, a novel mathematical framework to understand the training procedure of multilayer Transformer architectures. This is achieved by integrating out the self-attention layer in Transformers, producing modified dynamics of the MLP layers only. JoMA removes unrealistic assumptions in previous analyses (e.g., lack of residual connections) and predicts that, in the presence of nonlinear activations, the attention first becomes sparse (to learn salient tokens) and then dense (to learn less salient tokens), while in the linear case it is consistent with existing works showing that attention becomes sparse over time. We leverage JoMA to qualitatively explain how tokens are combined to form hierarchies in multilayer Transformers when the input tokens are generated by a latent hierarchical generative model. Experiments on models trained on real-world datasets (Wikitext2/Wikitext103) and on various pre-trained models (OPT, Pythia) verify our theoretical findings. Code can be found at https://github.com/facebookresearch/luckmatters/tree/yuandong3.
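To make the "sparse versus dense attention" claim concrete, the sketch below measures how concentrated a single attention row is, using the Shannon entropy of its softmax distribution. This is an illustrative assumption, not necessarily the metric used in the paper, and the names `attention_entropy`, `uniform_logits`, and `peaked_logits` are hypothetical. Low entropy corresponds to sparse attention (most mass on a few tokens); high entropy corresponds to dense attention.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def attention_entropy(logits):
    """Shannon entropy (in nats) of the attention distribution.

    Assumption: entropy is used here as a proxy for attention sparsity;
    lower values mean the attention is more concentrated (sparser).
    """
    probs = softmax(np.asarray(logits, dtype=float))
    safe = np.clip(probs, 1e-12, 1.0)  # avoid log(0)
    return -(probs * np.log(safe)).sum(axis=-1)


# Toy attention logits over 8 context tokens (made-up values).
uniform_logits = np.zeros(8)                        # dense: equal weight on all tokens
peaked_logits = np.array([10.0, 0, 0, 0, 0, 0, 0, 0])  # sparse: one dominant token

dense_entropy = attention_entropy(uniform_logits)
sparse_entropy = attention_entropy(peaked_logits)
print(f"dense: {dense_entropy:.4f} nats, sparse: {sparse_entropy:.4f} nats")
```

Tracking such a quantity per layer over training steps is one way to check whether attention first sharpens and later spreads out, under the assumption that entropy is an adequate summary of sparsity.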
