Superiority of Multi-Head Attention in In-Context Linear Regression

Yingqian Cui, Jie Ren, Pengfei He, Jiliang Tang, Yue Xing

TL;DR

This work theoretically and empirically establishes that multi-head softmax attention in transformers yields better in-context learning (ICL) for linear regression tasks than single-head attention, particularly when the embedding dimension is large. It derives exact asymptotic risk expressions, analyzes multiple data scenarios (prior knowledge, noisy labels, correlated features, local context), and provides practical design guidance on embedding dimension versus head count. The results are supported by simulations and experiments showing that multi-head attention achieves a smaller constant in the prediction risk and greater kernel flexibility. The findings offer concrete prescriptions for selecting attention mechanisms in ICL and motivate future work on finite-prompt and non-linear extensions.

Abstract

We present a theoretical analysis of the performance of transformers with softmax attention on in-context learning (ICL) of linear regression tasks. While the existing literature predominantly focuses on the convergence of transformers with single- or multi-head attention, our research centers on comparing their performance. We conduct an exact theoretical analysis to show that multi-head attention with a large embedding dimension performs better than single-head attention. As the number of in-context examples $D$ increases, the prediction loss is $O(1/D)$ for both single- and multi-head attention, but multi-head attention has a smaller multiplicative constant. In addition to the simplest data distribution setting, we consider further scenarios, e.g., noisy labels, local examples, correlated features, and prior knowledge. We observe that, in general, multi-head attention is preferred over single-head attention. Our results verify the effectiveness of the multi-head attention design in the transformer architecture.
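The ICL setting above can be made concrete with a small NumPy sketch of how one softmax attention head turns a prompt of $D$ labeled examples into a prediction. It rests on our own assumptions, not details confirmed by this summary: example $i$ is embedded as the column $z_i=(x_i, y_i)\in\mathbb{R}^{d+1}$, the query as $z_q=(x_q, 0)$, the read-out row is $(0,\ldots,0,v)$ as in Theorem 4.1, $(W^K)^{\top}W^Q$ takes the block form $\begin{pmatrix} A & 0\\ b^{\top} & 0\end{pmatrix}$, and the softmax runs over the $D$ examples only. The helpers `build_prompt`, `block_kq`, and `single_head_predict` are hypothetical names.

```python
import numpy as np


def build_prompt(X: np.ndarray, y: np.ndarray, x_query: np.ndarray):
    """Stack examples as columns z_i = (x_i, y_i) and the query as z_q = (x_q, 0)."""
    Z = np.vstack([X.T, y[None, :]])   # shape (d+1, D)
    z_q = np.append(x_query, 0.0)      # shape (d+1,), label slot set to 0
    return Z, z_q


def block_kq(A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical block form [[A, 0], [b^T, 0]] standing in for (W^K)^T W^Q."""
    d = A.shape[0]
    M = np.zeros((d + 1, d + 1))
    M[:d, :d] = A
    M[d, :d] = b
    return M


def single_head_predict(Z: np.ndarray, z_q: np.ndarray, M: np.ndarray, v: float) -> float:
    """Query prediction v * sum_i y_i * softmax_i(z_i^T M z_q).

    The value/output row is assumed to be (0, ..., 0, v), so only the label
    row of Z is read out.
    """
    scores = Z.T @ M @ z_q             # one score per in-context example
    scores = scores - scores.max()     # stabilize the exponentials
    weights = np.exp(scores)
    weights /= weights.sum()           # softmax over the D examples
    return float(v * Z[-1] @ weights)  # weighted label average, scaled by v


# Usage: d = 5 features, D = 1000 examples, (A, b) = (I_d / v, 0).
rng = np.random.default_rng(0)
d, D, v = 5, 1000, 3.0
beta = rng.normal(size=d) / np.sqrt(d)
X = rng.normal(size=(D, d))
x_q = rng.normal(size=d)
Z, z_q = build_prompt(X, X @ beta, x_q)
y_hat = single_head_predict(Z, z_q, block_kq(np.eye(d) / v, np.zeros(d)), v)
print(y_hat, beta @ x_q)
```

Because the softmax weights sum to one, the head always returns $v$ times a weighted average of the labels; the key-query block only decides how that average is weighted.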

Paper Structure (37 sections, 8 theorems, 133 equations, 21 figures, 2 tables)

Key Result

Theorem 4.1

Under the paper's data assumption and network assumption, assume that (1) there are infinitely many training prompts, (2) $(W_{out}W^V)_{d+1,:}=(0,\ldots,0,v)$, and (3) $(W^K)^{\top}W^Q$ has the form … . Then, as $D\rightarrow\infty$, the loss value is … , and the optimal solution satisfies $\|vA-I_d\|_F^2=O(1/D)$ and $\|vb\|^2=O(1/D)$. In addition, when taking $A=I_d/v$ and $b=0$, … Denoting the optimal solution …
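Theorem 4.1 says that, as $D\to\infty$, the optimal single-head solution approaches $A=I_d/v$, $b=0$. The Monte Carlo sketch below plugs that configuration into the attention predictor and checks that the risk shrinks roughly like $1/D$. It assumes a Gaussian design of our own choosing ($x\sim\mathcal{N}(0,I_d)$, $\beta\sim\mathcal{N}(0,I_d/d)$, noiseless $y=\beta^{\top}x$), not necessarily the paper's exact data model. Under that design the identity $\mathbb{E}[x\,e^{a^{\top}x}]/\mathbb{E}[e^{a^{\top}x}]=a$ with $a=x_q/v$ makes the softmax-weighted label average tend to $\beta^{\top}x_q/v$, so the read-out scale $v$ cancels the $1/v$. The names `kernel_predict` and `icl_risk` and the value $v=3$ are illustrative.

```python
import numpy as np


def kernel_predict(X, y, x_q, v):
    """Single head with (A, b) = (I_d / v, 0): v * sum_i y_i softmax_i(x_i . x_q / v)."""
    scores = X @ x_q / v
    w = np.exp(scores - scores.max())  # unnormalized softmax weights, stabilized
    return v * float(y @ w) / float(w.sum())


def icl_risk(D, d=5, v=3.0, trials=400, seed=0):
    """Monte Carlo estimate of E[(y_hat - beta^T x_q)^2] over random tasks and prompts.

    Assumed data model (illustrative only): x ~ N(0, I_d), beta ~ N(0, I_d / d),
    noiseless labels y = beta^T x. Under this model the averaged weight variance
    is finite only for v > sqrt(2); v = 3 stays well inside that range.
    """
    rng = np.random.default_rng(seed)
    sq_err = np.empty(trials)
    for t in range(trials):
        beta = rng.normal(size=d) / np.sqrt(d)
        X = rng.normal(size=(D, d))
        x_q = rng.normal(size=d)
        sq_err[t] = (kernel_predict(X, X @ beta, x_q, v) - beta @ x_q) ** 2
    return float(sq_err.mean())


for D in (250, 1000, 4000):
    r = icl_risk(D)
    print(f"D={D:5d}  risk={r:.4f}  D*risk={D * r:.1f}")
```

If the risk is $O(1/D)$, the `D*risk` column should stay roughly flat as $D$ grows; its level depends on $v$, $d$, and the assumed data model, and is not the constant derived in the paper.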

Figures (21)

  • Figure 1: ICL performance of single-head attention with $(A,b)=(I_d/v,0)$ and $D=1000$.
  • Figure 2: ICL performance of single-head attention with $(A,b)=(I_d/v,0)$ and $d=5$.
  • Figure 3: ICL performance of multi-head attention with $(m,n)=(2,1)$, $(A_1,A_2,b_1,b_2)=\left((c/v)I_d,((2c-1)/v)I_d,0,0\right)$, and $(d,D)=(5,1000)$.
  • Figure 4: A comparison between single-head and multi-head attention with input embedding dimension $p=64$.
  • Figure 5: An illustration of the matrix $(W^K)^{\top}W^Q$ for the no-read-in case. It is expected to be approximately of the form $\alpha I_d$; 4 of 10 trials show this pattern.
  • ...and 16 more figures
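Figure 3 pairs two heads whose key-query blocks are scaled identities, $A_1=(c/v)I_d$ and $A_2=((2c-1)/v)I_d$. A rough way to see why two heads add flexibility: assuming a Gaussian design ($x\sim\mathcal{N}(0,I_d)$, noiseless $y=\beta^{\top}x$), a head with $A=\alpha I_d$ and $b=0$ returns a softmax-weighted label average that tends to $\alpha\,\beta^{\top}x_q$, so any output weights $(w_1,w_2)$ with $w_1\alpha_1+w_2\alpha_2=1$ make the combined prediction consistent. The sketch uses the Figure 3 blocks with output weights $(2v,-v)$, which satisfy that condition. These weights, the data model, $c=0.8$, and the function names are our illustrative assumptions; we do not claim this is how the paper parameterizes $(m,n)$.

```python
import numpy as np


def head_average(X, y, x_q, alpha):
    """One head with A = alpha * I_d, b = 0: softmax-weighted average of the labels."""
    scores = alpha * (X @ x_q)
    w = np.exp(scores - scores.max())
    return float(y @ w) / float(w.sum())


def compare_heads(D=5000, d=5, v=3.0, c=0.8, trials=200, seed=1):
    """Mean squared error of each scaled head alone vs. the two-head combination."""
    rng = np.random.default_rng(seed)
    alpha1, alpha2 = c / v, (2 * c - 1) / v  # key-query scales from the Figure 3 caption
    w1, w2 = 2 * v, -v                       # assumed output weights: w1*alpha1 + w2*alpha2 = 1
    err = {"head1": 0.0, "head2": 0.0, "two_head": 0.0}
    for _ in range(trials):
        beta = rng.normal(size=d) / np.sqrt(d)
        X = rng.normal(size=(D, d))
        x_q = rng.normal(size=d)
        y, target = X @ beta, beta @ x_q
        h1 = head_average(X, y, x_q, alpha1)
        h2 = head_average(X, y, x_q, alpha2)
        err["head1"] += (v * h1 - target) ** 2  # tends to c * target: biased unless c = 1
        err["head2"] += (v * h2 - target) ** 2  # tends to (2c - 1) * target
        err["two_head"] += (w1 * h1 + w2 * h2 - target) ** 2
    return {k: e / trials for k, e in err.items()}


print(compare_heads())
```

Setting $c=1$ makes both blocks equal to $I_d/v$, so the heads coincide and the combination reduces to $(2v-v)\,h = v\,h$, the single-head choice of Theorem 4.1.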

Theorems & Definitions (17)

  • Theorem 4.1: Optimal Solution of Single-Head Attention
  • Remark 4.1
  • Theorem 4.2: Multi-head Attention is Better
  • Proposition 4.1
  • Theorem 5.1
  • Theorem 5.2
  • Theorem 5.3
  • Theorem 5.4
  • Theorem 5.5
  • Proof of Theorem 4.1
  • ...and 7 more