How Transformers Utilize Multi-Head Attention in In-Context Learning? A Case Study on Sparse Linear Regression

Xingwu Chen, Lei Zhao, Difan Zou

TL;DR

The paper addresses how transformers perform in-context learning for sparse linear regression and discovers a two-phase mechanism: the first layer uses multiple heads to preprocess the context, while later layers perform gradient-descent-like optimization with a single head. It formalizes a preprocess-then-optimize algorithm that can be implemented by linear-attention transformers and proves it can achieve lower excess risk than standard baselines such as gradient descent and ridge regression in many regimes, with results suggesting comparability to Lasso up to logarithmic factors when $d\le n$. Empirical analyses (head-masking, pruning, and P-probing) corroborate the mechanism, showing that keeping all heads in the first layer while selecting the most important head in subsequent layers preserves performance. The work provides a mechanistic explanation of multi-head attention in trained transformers and offers guidance for designing efficient, interpretable in-context learners.

Abstract

Despite the remarkable success of transformer-based models in various real-world tasks, their underlying mechanisms remain poorly understood. Recent studies have suggested that transformers can implement gradient descent as an in-context learner for linear regression problems and have developed various theoretical analyses accordingly. However, these works mostly focus on the expressive power of transformers by designing specific parameter constructions, lacking a comprehensive understanding of their inherent working mechanisms post-training. In this study, we consider a sparse linear regression problem and investigate how a trained multi-head transformer performs in-context learning. We experimentally discover that the utilization of multi-heads exhibits different patterns across layers: multiple heads are utilized and essential in the first layer, while usually only a single head is sufficient for subsequent layers. We provide a theoretical explanation for this observation: the first layer preprocesses the context data, and the following layers execute simple optimization steps based on the preprocessed context. Moreover, we demonstrate that such a preprocess-then-optimize algorithm can significantly outperform naive gradient descent and ridge regression algorithms. Further experimental results support our explanations. Our findings offer insights into the benefits of multi-head attention and contribute to understanding the more intricate mechanisms hidden within trained transformers.
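As a rough illustration of the preprocess-then-optimize idea described above, the NumPy sketch below compares a few gradient descent steps on raw context data with the same steps on rescaled data. The preprocessing rule used here (rescaling coordinate $j$ by the absolute empirical correlation $|\frac{1}{n}\sum_i x_{ij} y_i|$) is an assumption made for illustration and may differ from the paper's algorithm. The function names, the step-size rule, and the problem sizes are likewise hypothetical.

```python
import numpy as np

def make_sparse_task(n, d, s, noise_std, rng):
    """Sample one sparse linear regression task: y = <w_star, x> + noise."""
    support = rng.choice(d, size=s, replace=False)
    w_star = np.zeros(d)
    # Nonzero entries are +-1/sqrt(s), so that ||w_star||_2 = 1.
    w_star[support] = rng.choice([-1.0, 1.0], size=s) / np.sqrt(s)
    X = rng.standard_normal((n, d))
    y = X @ w_star + noise_std * rng.standard_normal(n)
    return X, y, w_star

def preprocess(X, y):
    """Assumed preprocessing: scale coordinate j by |(1/n) sum_i x_ij * y_i|."""
    r = np.abs(X.T @ y) / len(y)
    return X * r, r

def gradient_descent(X, y, steps):
    """GD on the loss (1/2n)||Xw - y||^2 from w = 0, with step size 1/L,
    where L is the largest eigenvalue of X^T X / n."""
    n, d = X.shape
    lr = n / np.linalg.norm(X, 2) ** 2
    w = np.zeros(d)
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / n
    return w

def preprocess_then_optimize(X, y, steps):
    """Run GD on the rescaled data, then map back to the original coordinates."""
    X_tilde, r = preprocess(X, y)
    v = gradient_descent(X_tilde, y, steps)
    # <v, r * x> == <r * v, x>, so r * v is the predictor for raw inputs.
    return r * v

rng = np.random.default_rng(0)
n, d, s = 20, 40, 4
X, y, w_star = make_sparse_task(n, d, s, noise_std=0.1, rng=rng)

w_gd = gradient_descent(X, y, steps=5)
w_po = preprocess_then_optimize(X, y, steps=5)

print("GD parameter error:                  ", np.sum((w_gd - w_star) ** 2))
print("Preprocess-then-GD parameter error:  ", np.sum((w_po - w_star) ** 2))
```

The rescaling shrinks coordinates that look uncorrelated with the labels, so the subsequent gradient steps concentrate on the likely support. A single run like this one is only a sanity check and says nothing about the excess-risk guarantees stated in the paper.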

Paper Structure (43 sections, 16 theorems, 108 equations, 12 figures, 1 algorithm)

Key Result

Theorem 5.1

Suppose $\mathcal{S}$ with $|\mathcal{S}|=s$ is selected such that each element is chosen with equal probability from the set $\{1,2,\ldots,d\}$, and $w^\star_i\sim \mathsf{U}\{-1/\sqrt{s}, 1/\sqrt{s}\}$ has a restricted uniform prior for $i\in\mathcal{S}$, […] with probability at least $1-\delta$. Besides, let $\widehat{\mathbf{w}}_{\lambda}$ be the ridge re[…]

Figures (12)

  • Figure 1: Experimental insights into multi-head attention for in-context learning. (a): Overview of the experiments, including task, data, architecture, and our insights. (b): ICL with Varying Heads. (c): Heads Assessment. (d): Pruning and Probing.
  • Figure 2: Supporting experiments for our preprocess-then-optimize algorithm and theoretical analysis
  • Figure 3: ICL with varying heads, layers and noise levels
  • Figure 4: Head Assessment with varying heads, layers
  • Figure 5: Pruning and Probing, 3 layers
  • ...and 7 more figures

Theorems & Definitions (26)

  • Theorem 5.1
  • Proposition 2.1: Single-layer multi-head transformer implements Algorithm 1
  • Proposition 2.2: Subsequent single-head transformer implements multi-step GD
  • Proposition 2.3: Restatement of Proposition 2.1
  • proof
  • Proposition 2.4: Restatement of Proposition 2.2
  • proof
  • Lemma 4.1
  • Lemma 4.2
  • Lemma 4.3: Theorem 9 in Bartlett et al. (2020)
  • ...and 16 more
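
To make the claim that a single attention head can carry out gradient descent more concrete, the sketch below checks numerically that one linear-attention layer (no softmax) with hand-picked weights reproduces the prediction of one gradient descent step from $\mathbf{w}=0$. This is a standard hand construction used for illustration, not the weights of a trained model and not necessarily the construction used in the paper's propositions. The token layout, the matrices `W_k`, `W_q`, `W_v`, and the step size `eta` are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, eta = 8, 5, 0.1
X = rng.standard_normal((n, d))      # context inputs x_i
y = rng.standard_normal(n)           # context labels y_i
x_q = rng.standard_normal(d)         # query input

# Reference: one GD step on (1/2n) * sum_i (<w, x_i> - y_i)^2 starting at w = 0.
w1 = eta * X.T @ y / n
pred_gd = w1 @ x_q

# Tokens: context token z_i = (x_i, y_i), query token z_q = (x_q, 0).
Z = np.hstack([X, y[:, None]])       # shape (n, d + 1)
z_q = np.append(x_q, 0.0)

# Keys and queries read only the x-part; values read only the label, scaled by eta.
W_k = np.zeros((d + 1, d + 1))
W_k[:d, :d] = np.eye(d)
W_q = W_k.copy()
W_v = np.zeros((d + 1, d + 1))
W_v[d, d] = eta

# Linear attention: score_i = <W_k z_i, W_q z_q> = <x_i, x_q>.
scores = (Z @ W_k.T) @ (W_q @ z_q)           # shape (n,)
update = (Z @ W_v.T).T @ scores / n          # shape (d + 1,)
z_out = z_q + update                         # residual connection

# The label slot of the query token now holds the one-step GD prediction.
pred_attn = z_out[d]
print(pred_attn, pred_gd)
```

Stacking several such layers, each writing a correction into the label slot, is the usual route to multi-step gradient descent, although the exact bookkeeping depends on the construction.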