Scan and Snap: Understanding Training Dynamics and Token Composition in 1-layer Transformer

Yuandong Tian, Yiping Wang, Beidi Chen, Simon Du

TL;DR

This work provides the first rigorous SGD-dynamics analysis of a 1-layer Transformer without positional encoding, trained for next-token prediction under cross-entropy loss. It introduces a Y/Z reparameterization to study pairwise token logits and derives a scan-and-snap behavior: self-attention learns to emphasize distinct, highly co-occurring tokens (frequency and discriminative biases) and then undergoes a phase transition, controlled by the learning rates of the two layers, that fixes the token combination. The theory explains why attention becomes sparse without collapsing to a single token, and predicts a two-stage evolution (scanning, then snapping) that is validated on synthetic data and WikiText. The results offer a mechanistic lens on how Transformer representations emerge during pretraining and motivate extensions to multi-layer architectures and more complex data distributions.

Abstract

The Transformer architecture has shown impressive performance in multiple research domains and has become the backbone of many neural network models. However, there is limited understanding of how it works. In particular, with a simple predictive loss, how the representation emerges from the gradient training dynamics remains a mystery. In this paper, for a 1-layer transformer with one self-attention layer plus one decoder layer, we analyze its SGD training dynamics for the task of next-token prediction in a mathematically rigorous manner. We open the black box of the dynamic process of how the self-attention layer combines input tokens, and reveal the nature of the underlying inductive bias. More specifically, under the assumptions of (a) no positional encoding, (b) long input sequences, and (c) a decoder layer that learns faster than the self-attention layer, we prove that self-attention acts as a discriminative scanning algorithm: starting from uniform attention, it gradually attends more to distinct key tokens for a specific next token to be predicted, and pays less attention to common key tokens that occur across different next tokens. Among distinct tokens, it progressively drops attention weights, following the order of low to high co-occurrence between the key and the query token in the training set. Interestingly, this procedure does not lead to winner-takes-all, but decelerates due to a phase transition that is controllable by the learning rates of the two layers, leaving an (almost) fixed token combination. We verify these scan and snap dynamics on synthetic and real-world data (WikiText).
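To make the setting in the abstract concrete, the following is a minimal sketch of a forward pass through a 1-layer model with no positional encoding. It is an illustration under stated assumptions, not the authors' implementation: tokens are represented as one-hot vectors, `Z[m][l]` stands for the pairwise attention logit between query token `m` and key token `l`, `Y[l][n]` stands for the decoder logit from token `l` to next token `n`, and the normalization step is taken to be plain L2 normalization. The function name `forward` and the toy sizes are hypothetical.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def forward(context, query, Y, Z):
    """Toy forward pass: self-attention, normalization, decoding.

    context: token ids x_1..x_{T-1} (order is irrelevant: no positional encoding)
    query:   token id x_T
    Y:       M x M decoder logits (assumed layout: Y[key token][next token])
    Z:       M x M pairwise attention logits (assumed layout: Z[query][key])
    """
    M = len(Y)
    # Attention depends only on token identities, never on positions.
    attn = softmax([Z[query][l] for l in context])
    # Attention-weighted sum of one-hot token embeddings.
    f = [0.0] * M
    for weight, l in zip(attn, context):
        f[l] += weight
    # L2 normalization (an assumed stand-in for the normalization layer).
    norm = math.sqrt(sum(v * v for v in f))
    f = [v / norm for v in f]
    # Decode to next-token probabilities.
    logits = [sum(f[l] * Y[l][n] for l in range(M)) for n in range(M)]
    return attn, softmax(logits)

M = 4
Y0 = [[0.0] * M for _ in range(M)]
Z0 = [[0.0] * M for _ in range(M)]
attn, probs = forward([0, 1, 1, 3], 2, Y0, Z0)

# Raising one pairwise logit shifts attention toward that key token.
Z1 = [row[:] for row in Z0]
Z1[2][3] = 2.0
attn_shifted, _ = forward([0, 1, 1, 3], 2, Y0, Z1)
```

With all entries of `Z` equal to zero, the attention weights are uniform over the context, which matches the "starting from uniform attention" initial condition described in the abstract.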

Paper Structure (28 sections, 33 theorems, 110 equations, 12 figures, 1 table)

This paper contains 28 sections, 33 theorems, 110 equations, 12 figures, 1 table.

Key Result

Lemma 1

The gradient dynamics of Eqn. \ref{eq:objective} with batch size 1 is: Here $P^{\perp}_{{\bm{v}}} := I - {\bm{v}}{\bm{v}}^\top/\|{\bm{v}}\|_2^2$ projects a vector onto the orthogonal complement of ${\bm{v}}$, $\eta_Y$ and $\eta_Z$ are the learning rates for the decoder layer $Y$ and the self-attention layer $Z$, $\boldsymbol{\alpha} := [\alpha_1,\ldots,\alpha_M]^\top$ ...
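The projector $P^{\perp}_{{\bm{v}}}$ defined in the lemma can be checked numerically. The sketch below applies $I - {\bm{v}}{\bm{v}}^\top/\|{\bm{v}}\|_2^2$ to a vector without forming the matrix; the helper name `project_orthogonal` and the sample vectors are illustrative assumptions, not taken from the paper.

```python
def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

def project_orthogonal(v, x):
    """Apply P_v^perp = I - v v^T / ||v||_2^2 to x, for a nonzero vector v."""
    coef = dot(v, x) / dot(v, v)
    # Subtract the component of x that lies along v.
    return [xi - coef * vi for vi, xi in zip(v, x)]

v = [3.0, 4.0, 0.0]
x = [1.0, 2.0, 5.0]
p = project_orthogonal(v, x)        # component of x orthogonal to v
p_again = project_orthogonal(v, p)  # projecting a second time changes nothing
```

The result `p` has zero inner product with `v`, and applying the projector a second time leaves it unchanged, which are the two defining properties of an orthogonal projection.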

Figures (12)

  • Figure 1: Overview of our setting. (a) A sequence with contextual tokens $\{x_1,\ldots, x_{T-1}\}$ and query token $x_T$ is fed into a 1-layer transformer (self-attention, normalization and decoding) to predict the next token $x_{T+1}$. (b) The definition of sequence classes (Sec. \ref{sec:data-gen}). A sequence class specifies the conditional probability $\mathbb{P}(l|m,n)$ of the contextual tokens, given the query token $x_T = m$ and the next token $x_{T+1} = n$. For simplicity, we consider the case in which the query token is determined by the next token: $x_T = \psi(x_{T+1})$ (and thus $\mathbb{P}(l|m,n) = \mathbb{P}(l|n)$), while the same query token $m$ may correspond to multiple next tokens (i.e., $\psi^{-1}(m)$ is not unique). We study two kinds of tokens: common tokens (CT) with $\mathbb{P}(l|n) > 0$ for multiple sequence classes $n$, and distinct tokens (DT) with $\mathbb{P}(l|n) > 0$ for a single sequence class $n$ only.
  • Figure 2: Overview of the training dynamics of the self-attention map. Here $\tilde{c}_{l|m,n} := \mathbb{P}(l|m,n)\exp(z_{ml})$ is the un-normalized attention score (Eqn. \ref{eq:c_lmn}). (a) Initialization stage. $z_{ml}(0) = 0$ and $\tilde{c}_{l|m,n} = \mathbb{P}(l|m,n)$. Distinct tokens (Sec. \ref{sec:data-gen}) are shown in blue, common tokens in yellow. (b) Common tokens (CT) are suppressed ($\dot z_{ml} < 0$, Theorem \ref{thm:dynamic-property-of-token}). (c) Winners-take-all stage. Distinct tokens (DT) with large initial value $\tilde{c}_{l|m,n}(0)$ start to dominate the attention map (Sec. \ref{sec:dyn-QK}, Theorem \ref{thm:dyn-fate}). (d) Once past the phase transition, i.e., $t \ge t_0 = O(K\ln M / \eta_Y)$, attention appears (almost) frozen (Sec. \ref{sec:snapping}) and the token composition is fixed in the self-attention layer.
  • Figure 3: Growth factor $\chi_l(t)$ (Theorem \ref{thm:dyn-fate}) over time with fixed $\eta_Z = 0.5$ and varying $\eta_Y$. Each solid line is $\chi_l(t)$, and the dotted line of the same color marks the transition time $t_0$ for the given $\eta_Y$.
  • Figure 4: Visualization of ${\bm{c}}_n$ ($n=1,2$) in the training dynamics of a 1-layer Transformer using SGD in the Syn-Small setting. Top row for query token $n=1$ and bottom row for query token $n=2$. Left: SGD training with $\eta_Y = \eta_Z = 1$. The attention pattern ${\bm{c}}_n$ becomes sparse and concentrated on the highest $\mathbb{P}(l|n)$ (rightmost) for each sequence class (Theorem \ref{thm:dyn-fate}). Right: SGD training with $\eta_Y = 10$ and $\eta_Z = 1$. With larger $\eta_Y$, convergence becomes faster but the final attention maps are less sparse (Sec. \ref{sec:snapping}).
  • Figure 5: Visualization of (part of) ${\bm{c}}_n$ for sequence class $n=1$ in the training dynamics using Adam (Kingma and Ba, 2014) in the Syn-Small setting. From left to right: $\eta_V=\eta_Z = 0.1, 0.5, 1$. With different learning rates, Adam seems to steer self-attention towards different subsets of distinct tokens, showing a tunable inductive bias.
  • ...and 7 more figures
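
The definitions in the captions of Figures 1 and 2 can be illustrated with a small sketch. It classifies contextual tokens as common or distinct from a table of $\mathbb{P}(l|n)$ and computes the un-normalized attention score $\tilde{c}_{l|m,n} = \mathbb{P}(l|m,n)\exp(z_{ml})$. The probability table, the class names `n1` and `n2`, the function names, and the choice to normalize the scores so they sum to 1 are all assumptions made for illustration.

```python
import math

# Hypothetical P(l | n): one row per sequence class n, one column per token l.
P = {
    "n1": [0.5, 0.3, 0.0, 0.2],
    "n2": [0.0, 0.0, 0.6, 0.4],
}

def classify_tokens(P):
    """Label each contextual token by how many sequence classes give it P(l|n) > 0."""
    num_tokens = len(next(iter(P.values())))
    kinds = []
    for l in range(num_tokens):
        support = [n for n, row in P.items() if row[l] > 0]
        if len(support) == 1:
            kinds.append("distinct")  # DT: appears in a single class only
        elif len(support) > 1:
            kinds.append("common")    # CT: appears in multiple classes
        else:
            kinds.append("unused")
    return kinds

def attention_scores(p_row, z_row):
    """Return (c_tilde, c): c_tilde[l] = P(l|n) * exp(z[l]); c is c_tilde scaled to sum to 1."""
    c_tilde = [p * math.exp(z) for p, z in zip(p_row, z_row)]
    total = sum(c_tilde)
    return c_tilde, [c / total for c in c_tilde]

kinds = classify_tokens(P)

# At initialization z = 0, so the un-normalized score equals P(l|n).
c_tilde_init, c_init = attention_scores(P["n1"], [0.0, 0.0, 0.0, 0.0])

# Lowering the logit of the common token (index 3) shrinks its share of attention.
_, c_later = attention_scores(P["n1"], [0.0, 0.0, 0.0, -1.0])
```

In this toy table, tokens 0 to 2 each appear in one class and token 3 appears in both, so only token 3 is a common token. Pushing its logit below zero, as in panel (b) of Figure 2, reduces its normalized share relative to the distinct tokens.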

Theorems & Definitions (61)

  • Lemma 1: Dynamics of 1-layer Transformer
  • Lemma 2
  • Lemma 3
  • Theorem 1
  • Lemma 4: Self-attention dynamics
  • Theorem 2: Fates of contextual tokens
  • Theorem 3: Growth of distinct tokens
  • Theorem 4: Phase Transition in Training
  • Lemma 4: Dynamics of 1-layer Transformer
  • proof
  • ...and 51 more