In-Context Linear Regression Demystified: Training Dynamics and Mechanistic Interpretability of Multi-Head Softmax Attention

Jianliang He, Xintian Pan, Siyu Chen, Zhuoran Yang

TL;DR

This work analyzes how a one-layer, multi-head softmax Transformer learns in-context linear regression on Gaussian data. The authors reveal consistent emergent patterns in the learned weights: a diagonal, homogeneous KQ structure and a last-entry, zero-sum OV pattern, across heads, which enable the model to implement a debiased gradient-descent predictor. They provide a mechanistic theory linking gradient dynamics to the observed patterns, showing that multi-head attention closely tracks a debiased GD predictor and approaches Bayesian optimal performance up to a proportional factor, while softmax attention generalizes to longer sequences better than linear attention. The study extends to anisotropic covariates and multi-task regression, where heads allocate to tasks or exhibit a superposition regime, demonstrating the model’s capacity to distribute representational power adaptively. Overall, the results offer a principled understanding of in-context learning as an aggregated effect of architecture and data distribution, with implications for interpretability and broader application.

Abstract

We study how multi-head softmax attention models are trained to perform in-context learning on linear data. Through extensive empirical experiments and rigorous theoretical analysis, we demystify the emergence of elegant attention patterns: a diagonal and homogeneous pattern in the key-query (KQ) weights, and a last-entry-only and zero-sum pattern in the output-value (OV) weights. Remarkably, these patterns consistently appear from gradient-based training starting from random initialization. Our analysis reveals that such emergent structures enable multi-head attention to approximately implement a debiased gradient descent predictor -- one that outperforms single-head attention and nearly achieves Bayesian optimality up to a proportional factor. Furthermore, compared to linear transformers, softmax attention readily generalizes to sequences longer than those seen during training. We also extend our study to scenarios with anisotropic covariates and multi-task linear regression. In the former, multi-head attention learns to implement a form of pre-conditioned gradient descent. In the latter, we uncover an intriguing regime where the interplay between the number of heads and the number of tasks triggers a superposition phenomenon that efficiently resolves multi-task in-context learning. Our results reveal that in-context learning ability emerges from the trained transformer as an aggregated effect of its architecture and the underlying data distribution, paving the way for deeper understanding and broader applications of in-context learning.
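
As a rough illustration of the claim that the diagonal KQ and zero-sum OV patterns let attention approximate a debiased gradient descent predictor, the sketch below compares two toy predictors. It is one reading of the abstract, not the authors' code. The function names, the choice of exactly two heads with KQ blocks `+omega*I` and `-omega*I`, the OV weights `+mu` and `-mu`, the noise-free labels, and the scaling `mu = 1/(2*omega)` are all assumptions made for the example. For small `omega`, a first-order expansion of the softmax gives `softmax(omega*s) - softmax(-omega*s) ≈ (2*omega/L) * (s - mean(s))`, so the two-head output is close to one gradient step from zero on covariates centered by their in-context mean.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def two_head_attention(xs, ys, x_q, omega, mu):
    """Hypothetical 2-head model: KQ = +omega*I and -omega*I, OV = +mu and -mu."""
    scores = [dot(x, x_q) for x in xs]
    pos = softmax([omega * s for s in scores])
    neg = softmax([-omega * s for s in scores])
    # Zero-sum OV weights: the two heads' label averages are subtracted.
    return mu * sum((p - n) * y for p, n, y in zip(pos, neg, ys))


def debiased_gd(xs, ys, x_q, eta):
    """One GD step from zero, with covariates centered by their in-context mean."""
    L, d = len(xs), len(x_q)
    x_bar = [sum(x[k] for x in xs) / L for k in range(d)]
    w = [eta / L * sum((x[k] - x_bar[k]) * y for x, y in zip(xs, ys))
         for k in range(d)]
    return dot(w, x_q)


def demo(d=5, L=50, omega=0.01, seed=0):
    rng = random.Random(seed)
    beta = [rng.gauss(0, 1) for _ in range(d)]
    xs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(L)]
    ys = [dot(beta, x) for x in xs]          # noise-free labels (assumption)
    x_q = [rng.gauss(0, 1) for _ in range(d)]
    mu = 1 / (2 * omega)                     # makes the effective step size 1
    return two_head_attention(xs, ys, x_q, omega, mu), debiased_gd(xs, ys, x_q, 1.0)
```

Calling `demo()` returns the two predictions for one random prompt. They should agree closely when `omega` is small, and the gap grows as `omega` increases and the softmax leaves its linear regime.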

Paper Structure

This paper contains 86 sections, 9 theorems, 227 equations, 36 figures, 1 table.

Key Result

Proposition 4.1

Consider an $H$-head attention model parameterized as in \ref{eq:simple_model} with dimension $d\in\mathbb{Z}^+$ and sample size $L\in\mathbb{Z}^+$. Suppose that $d>\log L$ and that the parameters $(\omega,\mu)\subseteq\mathbb{R}^{2H}$ satisfy $\|\omega\|_\infty\lesssim\sqrt{\log L/d}$ and $\|\mu\|_\infty\lesssim$ … where $\exp(\cdot)$ is applied element-wise.

Figures (36)

  • Figure 1: Learned Weights of 2-Head Attention.
  • Figure 2: Training Dynamics of Attention Weights.
  • Figure 4: Illustration of the derivation from the full multi-head attention architecture in \ref{eq:std_TF} to the simplified model. In the graph, we show how the predictor $\widehat{y}_q$ is generated by the transformer, based on the token embeddings, consolidated KQ and OV circuits, and the read-out function. We use $*$ to denote the ineffective parameters due to the read-out function or the zero value in the query token. In particular, the token embedding is given in \ref{eq:embed} and the read-out function extracts the $(L+1, d+1)$-th entry of the transformer output. Thus, $\widehat{y}_q$ only depends on the KQ and OV matrices through some specific submatrices, as shown in the figure.
  • Figure 5: Learned Weights of Single-head Attention.
  • Figure 6: 4-Head KQ Dynamics.
  • ...and 31 more figures
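
The Figure 4 caption notes that some KQ and OV parameters are ineffective because of the read-out and the zero label slot in the query token. The sketch below is a minimal check of that idea under assumed conventions that the listing above does not specify: tokens are rows `z_i = (x_i, y_i)`, the query token is `(x_q, 0)`, the attention score is `z_i^T W z_q`, the query attends only to the `L` context tokens, and the head output is `sum_i p_i V z_i`. The names `attention_readout`, `W`, and `V` are hypothetical.

```python
import math
import random


def attention_readout(xs, ys, x_q, W, V):
    """Single softmax head; returns the last feature of the query token's output."""
    d = len(x_q)
    ctx = [x + [y] for x, y in zip(xs, ys)]   # context tokens z_i = (x_i, y_i)
    z_q = x_q + [0.0]                         # query token: label slot is zero
    Wz = [sum(W[k][l] * z_q[l] for l in range(d + 1)) for k in range(d + 1)]
    scores = [sum(z[k] * Wz[k] for k in range(d + 1)) for z in ctx]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    # Read-out keeps only output coordinate d, so only row d of V is used.
    return sum(e / total * sum(V[d][k] * z[k] for k in range(d + 1))
               for e, z in zip(exps, ctx))


def demo_ineffective(d=3, L=8, seed=0):
    rng = random.Random(seed)

    def rand(rows, cols):
        return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

    xs = rand(L, d)
    ys = [rng.gauss(0, 1) for _ in range(L)]
    x_q = [rng.gauss(0, 1) for _ in range(d)]
    W, V = rand(d + 1, d + 1), rand(d + 1, d + 1)
    before = attention_readout(xs, ys, x_q, W, V)
    for k in range(d + 1):
        W[k][d] = rng.gauss(0, 1)             # column hit by the zero label slot
    for k in range(d):
        V[k] = [rng.gauss(0, 1) for _ in range(d + 1)]  # rows never read out
    after = attention_readout(xs, ys, x_q, W, V)
    return before, after
```

Under these conventions `demo_ineffective()` returns the same prediction before and after the perturbation, which is the sense in which those entries can be marked `*` and dropped from the simplified model.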

Theorems & Definitions (24)

  • Proposition 4.1: Informal
  • Theorem 4.2
  • Definition 5.1: Multi-task Linear Model
  • Definition 5.2
  • Definition B.1: Initialization
  • Proposition C.1: Formal Statement of Proposition \ref{prop:informal_approx_loss}
  • Proof of Proposition \ref{thm:approx_loss}
  • Lemma C.2
  • Proof of Lemma \ref{lem:2stein}
  • Lemma C.3
  • ...and 14 more