JoMA: Demystifying Multilayer Transformers via JOint Dynamics of MLP and Attention

Yuandong Tian, Yiping Wang, Zhenyu Zhang, Beidi Chen, Simon Du

TL;DR

JoMA provides a unified framework that merges self-attention with MLP dynamics to analyze training in multilayer Transformers, relaxing assumptions of prior analyses such as the absence of residual connections and purely linear activations. It shows that, under nonlinear activations, the attention logits obey invariants that produce an initially sparse focus on salient tokens followed by denser attention as training progresses, while linear activation recovers the sparse attention reported in earlier work. Using a hierarchical latent-tree generative model, JoMA qualitatively explains how hierarchical token structure emerges across layers, and the theory is validated on WikiText and on pre-trained models (OPT, Pythia). The work advances the understanding of how self-attention interacts with nonlinear MLPs to produce hierarchical representations, offering insights for analyzing and designing Transformer training dynamics.

Abstract

We propose Joint MLP/Attention (JoMA) dynamics, a novel mathematical framework to understand the training procedure of multilayer Transformer architectures. This is achieved by integrating out the self-attention layer in Transformers, producing a modified dynamics of MLP layers only. JoMA removes unrealistic assumptions in previous analyses (e.g., the lack of residual connections) and predicts that, in the presence of nonlinear activations, the attention first becomes sparse (to learn salient tokens) and then dense (to learn less salient tokens), while in the linear case it is consistent with existing works showing that attention becomes sparse over time. We leverage JoMA to qualitatively explain how tokens are combined to form hierarchies in multilayer Transformers when the input tokens are generated by a latent hierarchical generative model. Experiments on models trained on real-world datasets (Wikitext2/Wikitext103) and on various pre-trained models (OPT, Pythia) verify our theoretical findings. Code can be found at https://github.com/facebookresearch/luckmatters/tree/yuandong3.

Paper Structure

This paper contains 22 sections, 18 theorems, 79 equations, 14 figures, and 1 table.

Key Result

Theorem 1

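Let ${\bm{v}}_k := U_C^\top {\bm{w}}_k$. Then the training dynamics of the one-layer model satisfy invariants that tie the self-attention logits ${\bm{z}}_m(t)$ to ${\bm{v}}_k(t)$. For softmax attention, under assumptions on $\bar{{\bm{b}}}_m$, the invariant reads ${\bm{z}}_m(t) = \frac{1}{2}\sum_k {\bm{v}}_k^2(t) - \frac{1}{2}\sum_k \|{\bm{v}}_k(t)\|_2^2\,\bar{{\bm{b}}}_m + {\bm{c}}$, where ${\bm{v}}_k^2$ is taken elementwise. Under zero initialization (${\bm{w}}_k(0) = 0$, ${\bm{z}}_m(0) = 0$), the time-independent constant is ${\bm{c}} = 0$.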
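Theorem 1 ties the attention logits to ${\bm{v}}_k$ through a conserved quantity. As a rough numerical illustration, the sketch below uses a hypothetical toy, not the paper's exact model: one query, $U_C = I$ (so ${\bm{v}}_k = {\bm{w}}_k$), no MLP nonlinearity, unnormalized exp attention $b_l = x_l e^{z_l}/A$, a squared loss, and zero initialization. In this toy, $\partial h_k/\partial z_l = w_{kl} b_l$, so under gradient flow $\dot z_l = \sum_k w_{kl}\dot w_{kl}$ and $z_l - \frac{1}{2}\sum_k w_{kl}^2$ stays at its initial value 0. Plain gradient descent only approximately conserves it. The function names are illustrative.

```python
import math


def simulate_exp_attention(x, y, lr=0.01, steps=3000, A=1.0):
    """Gradient descent on a hypothetical one-layer toy with exp attention.

    Assumptions (not the paper's exact setting): a single query, U_C = identity
    (so v_k = w_k), no MLP nonlinearity, squared loss 0.5 * sum_k (h_k - y_k)^2,
    and zero initialization of both the weights w_k and the logits z.
    """
    L, K = len(x), len(y)
    z = [0.0] * L                       # attention logits z_l, zero init
    w = [[0.0] * L for _ in range(K)]   # w_k (equal to v_k in this toy), zero init
    for _ in range(steps):
        b = [x[l] * math.exp(z[l]) / A for l in range(L)]          # exp attention
        h = [sum(w[k][l] * b[l] for l in range(L)) for k in range(K)]
        g = [h[k] - y[k] for k in range(K)]                        # dLoss/dh_k
        grad_w = [[g[k] * b[l] for l in range(L)] for k in range(K)]
        # dh_k/dz_l = w_kl * b_l, hence dLoss/dz_l = b_l * sum_k g_k * w_kl
        grad_z = [b[l] * sum(g[k] * w[k][l] for k in range(K)) for l in range(L)]
        # Both gradients use the pre-update values, then update simultaneously.
        for k in range(K):
            for l in range(L):
                w[k][l] -= lr * grad_w[k][l]
        for l in range(L):
            z[l] -= lr * grad_z[l]
    return z, w


def invariant_gap(z, w):
    """Per-token z_l - 0.5 * sum_k w_kl^2; close to 0 when the invariant holds."""
    return [z[l] - 0.5 * sum(wk[l] ** 2 for wk in w) for l in range(len(z))]


z, w = simulate_exp_attention(x=[1.0, 0.5, 0.1], y=[1.0, -0.5])
print("logits z:", [round(zl, 4) for zl in z])
print("max |gap|:", max(abs(d) for d in invariant_gap(z, w)))
```

The logits move well away from zero while the gap stays small; lowering `lr` should shrink the gap roughly in proportion, since it comes only from the discretization of the gradient flow.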

Figures (14)

  • Figure 1: (a) Overview of the JoMA framework. Using an invariant of the training dynamics, the self-attention layer and the lower MLP layer can be merged into a single MLP layer with modified dynamics (Theorem 1), which explains the behavior of attention under linear and nonlinear MLP activations $\phi$, as well as hierarchical concept learning in the multilayer case. (b) Problem setting. The JoMA framework supports several kinds of attention, including linear attention $b_l := x_l z_{ql}$, exp attention $b_l := x_l e^{z_{ql}}/A$, and softmax attention $b_l := x_l e^{z_{ql}} / \sum_{l'} x_{l'} e^{z_{ql'}}$.
  • Figure 2: Test of training dynamics with linear MLP activation ($\phi(x)=x$) under softmax attention. Left two: the distribution of ${\bm{x}}$ transitions smoothly across class labels. Right two: the distributions of ${\bm{x}}$ for different classes are randomly generated. In both cases, despite the assumptions on $\bar{{\bm{b}}}_m$, the estimate $\hat{{\bm{z}}}_m(t)$ given by the first integral (Theorem 1) correlates strongly with the ground-truth self-attention logits ${\bm{z}}_m(t)$, whereas neither of its two components, $\hat{{\bm{z}}}_{m1}(t) := \frac{1}{2}\sum_k {\bm{v}}_k^2(t)$ and $\hat{{\bm{z}}}_{m2}(t) := -\frac{1}{2}\sum_k \|{\bm{v}}_k(t)\|_2^2\bar{{\bm{b}}}_m$, does so on its own.
  • Figure 3: Growth of different components of ${\bm{v}}_0(t)$ (the first few components of the first column of $V(t)$) under linear MLP activation and softmax attention. As predicted by the linear-activation analysis, after convergence only a few components of ${\bm{v}}_0$ keep growing, while the remaining components saturate after an initial growth phase. This is consistent with Theorem 2, even though that result is derived from JoMA's approximation in Theorem 1. Each node $k$ (and thus ${\bm{w}}_k$) receives the back-propagated gradient from the $k$-th class via the cross-entropy loss.
  • Figure 4: Dynamics of a nonlinear MLP with the self-attention components included. Left: training dynamics (color indicates the training step). The salient components of ${\bm{v}}(t)$ (i.e., those with large magnitude in $\boldsymbol{\mu}$) are learned first, followed by the non-salient ones. Right: the entropy of the attention (i.e., $\mathrm{entropy}(\mathrm{softmax}({\bm{v}}^2))$) drops while the salient components are learned, then rebounds as the other components catch up.
  • Figure 5: (a) Hierarchical binary tree generative model. Except for $y_0$, the observable label of a sequence, which can take $D$ discrete values, all latent variables follow a binomial distribution. A binary leaf variable $y_l = 1$ indicates that token $l$ appears in the sequence. (b) Attention dynamics in the multilayer setting. The query $m$ co-occurs strongly with token $l$ but only weakly with token $l'$. As a result, $m$ associates with $l$ first and eventually with $l'$, despite their weak co-occurrence, as predicted by Theorem 4. (c) If the latent hierarchy contains an additional layer with variables $y_\beta$ and $y_{\beta'}$, the associations $m$-$l$ and $m'$-$l'$ are learned first because of their high co-occurrence. Once the lower level of the hierarchy is learned and some hidden nodes in the MLP represent $y_\beta$ and $y_{\beta'}$ (validated experimentally in the paper), $y_\beta$ and $y_{\beta'}$ co-occur strongly at the next level and are picked up by the self-attention mechanism to form even higher-level features. In contrast, the association $l'$-$m$ is learned much more slowly and does not affect the learning of the latent hierarchy, showing that the self-attention mechanism adapts to the structure of the data distribution.
  • ...and 9 more figures
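The three attention variants in Figure 1 (b) and the entropy measure plotted in Figure 4 can be written out directly. The Python sketch below is illustrative only and is not taken from the linked repository. It handles a single query $q$, treats the exp-attention normalizer $A$ as a fixed constant (an assumption), and uses hypothetical function names.

```python
import math


def attention_weights(x, z, kind="softmax", A=1.0):
    """b_l for one query q, following the three variants in Figure 1 (b).

    x: per-token weights x_l (e.g. token indicators); z: logits z_{ql}.
    A: normalizer for exp attention, assumed here to be a fixed constant.
    """
    if kind == "linear":
        return [xl * zl for xl, zl in zip(x, z)]
    if kind == "exp":
        return [xl * math.exp(zl) / A for xl, zl in zip(x, z)]
    if kind == "softmax":
        m = max(z)  # subtract the max logit for stability; it cancels in the ratio
        e = [xl * math.exp(zl - m) for xl, zl in zip(x, z)]
        total = sum(e)
        return [el / total for el in e]
    raise ValueError(f"unknown attention kind: {kind}")


def attention_entropy(v):
    """entropy(softmax(v^2)) with v^2 taken elementwise, the quantity in Figure 4."""
    sq = [vi * vi for vi in v]
    m = max(sq)
    e = [math.exp(s - m) for s in sq]
    total = sum(e)
    p = [ei / total for ei in e]
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


print(attention_weights([1.0, 1.0, 0.0], [0.2, 0.5, 1.0], kind="softmax"))
print(attention_entropy([1.0, 1.0, 1.0, 1.0]))   # uniform case: equals log(4)
print(attention_entropy([3.0, 0.0, 0.0, 0.0]))   # one dominant component: much lower
```

A token with $x_l = 0$ receives zero weight in all three variants. In `attention_entropy`, equal magnitudes give the maximal value $\log n$, and a single large component drives it toward 0. This matches the sparse-then-dense reading of Figure 4.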

Theorems & Definitions (33)

  • Theorem 1
  • Theorem 2: Linear Dynamics with Self-attention
  • Theorem 3: Dynamics of nonlinear activation with uniform attention
  • Theorem 4: Convergence speed of salient vs. non-salient components
  • Theorem 5: Token Co-occurrence in $\mathrm{HBLT}(\rho)$
  • Theorem 5
  • proof
  • Theorem 5: Linear Dynamics with Self-attention
  • proof
  • Lemma 1: Expectation of Hyperplane function under Isotropic distribution
  • ...and 23 more