Inductive Biases and Variable Creation in Self-Attention Mechanisms

Benjamin L. Edelman, Surbhi Goel, Sham Kakade, Cyril Zhang

TL;DR

This work provides a rigorous statistical analysis of self-attention and Transformer architectures, revealing a sparse-variable creation inductive bias: with bounded weight norms, a single self-attention head can realize sparse input dependencies and generalize with a sample complexity that scales only logarithmically with context length. The authors develop a novel covering-number-based capacity bound for attention modules and give matching representation results showing that sparse functions can be encoded efficiently. They extend these results to multi-layer Transformers and discuss practical implications of positional embeddings, residuals, and multi-head attention. Through synthetic experiments, the paper confirms the predicted $\log T$ scaling in learning sparse Boolean functions and highlights interesting phenomena such as parity learning under i.i.d. sampling. Overall, the work provides a principled bridge between the practical success of attention mechanisms and their theoretical capacity to represent sparse, long-range dependencies.

Abstract

Self-attention, an architectural motif designed to model long-range interactions in sequential data, has driven numerous recent breakthroughs in natural language processing and beyond. This work provides a theoretical analysis of the inductive biases of self-attention modules. Our focus is to rigorously establish which functions and long-range dependencies self-attention blocks prefer to represent. Our main result shows that bounded-norm Transformer networks "create sparse variables": a single self-attention head can represent a sparse function of the input sequence, with sample complexity scaling only logarithmically with the context length. To support our analysis, we present synthetic experiments to probe the sample complexity of learning sparse Boolean functions with Transformers.
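
As a rough illustration of the "sparse variable" idea in the abstract, the sketch below hand-sets the alignment scores of one softmax attention head so that it concentrates on a few relevant positions, then thresholds the attended average to compute a sparse AND. The function names, the position-only scores, and the score scale of $2 \log T + 5$ are assumptions made for this example. It is not the paper's construction, only a toy showing that a score gap growing like $\log T$ is enough to suppress $T$ irrelevant positions.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_attention_average(bits, relevant, scale):
    """Attention-weighted average of `bits`.

    Assumption of this sketch: the alignment score is `scale` on the
    relevant positions and 0 elsewhere (a purely positional score).
    """
    scores = [scale if t in relevant else 0.0 for t in range(len(bits))]
    weights = softmax(scores)
    return sum(w * b for w, b in zip(weights, bits))


def sparse_and(bits, relevant, scale=None):
    """AND of the bits at `relevant`, computed through one attention head."""
    T = len(bits)
    s = len(relevant)
    if scale is None:
        # Hypothetical choice: a score gap on the order of log T.
        scale = 2.0 * math.log(T) + 5.0
    avg = sparse_attention_average(bits, relevant, scale)
    # All relevant bits set gives an average near 1; one missing bit
    # pulls it down to roughly 1 - 1/s, so threshold halfway between.
    return int(avg > 1.0 - 1.0 / (2 * s))
```

For example, `sparse_and(bits, {5, 20, 30})` on a length-50 bit list depends only on positions 5, 20, and 30 (0-based here), regardless of the other 47 bits.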

Paper Structure

This paper contains 77 sections, 32 theorems, 178 equations, 4 figures.

Key Result

Lemma 2.2

Suppose $\mathcal{F}$ is a class of bounded functions, and $\log \mathcal{N}_\infty(\mathcal{F}; \varepsilon; x^{(1)}, \dots, x^{(m)}) \le C_\mathcal{F}/{\varepsilon}^2$ for all $x^{(1)}, \ldots, x^{(m)} \in \mathcal{X}^m$. Then for any $\delta > 0$, with probability at least $1 - \delta$, simultane

Figures (4)

  • Figure 1: Diagrams of attention modules $f_{\mathsf{tf{\text{-}}head}}, f_{\mathsf{tf{\text{-}}layer}}, f_{\mathsf{tf{\text{-}}scalar}}$ described in Section \ref{subsec:prelims-attn}: alignment scores (grey edges) determine normalized attention weights (blue), which are used to mix the inputs $x_{1:T}$. Left: Attention with a general context $z$. Center: Self-attention layer, where both the input and the context come from $x_{1:T}$. Right: Auxiliary $\textup{[CLS]}$ token to extract a single scalar from a self-attention layer, providing a real-valued function class for classification or regression tasks.
  • Figure 2: Main experimental finding: the sample complexity of learning a $3$-sparse $\mathsf{AND}$ function of $T$ input bits with Transformers. For each $T$, we measure the smallest sample size $m$ necessary to reach $100\%$ validation accuracy on $\geq 80\%$ of random trials. We find that this threshold scales logarithmically with $T$.
  • Figure 3: Additional visualizations for the sparse function learning experiments. Left: Examples of validation accuracy curves on the same problem instance ($T = 300$), with sample sizes above ($m = 200$) and below ($m = 50$) the threshold ($\approx 70$ from Figure \ref{fig:sparse-and-scaling-laws}). Training accuracy goes to $100\%$ in both cases, but the Transformer overfits (orange curves) when $m$ is too small. Right: Per-example attention weights for a successfully trained model ($T = 50$, $m = 300$, ${\mathcal{I}} = \{5, 20, 30\}$). The input-dependent attention weights approximately zero out the irrelevant bits.
  • Figure 4: A curious empirical finding: Transformers can learn sparse parities. Loss curves (across 10 random seeds for initialization and SGD samples) are shown for this setup with $s=3, T \in \{10,15\}$, exhibiting phase transitions from random guessing to $100\%$ accuracy. See Appendix \ref{subsec:appendix-sparse-parity} for details.
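
The experimental setup in the captions above can be sketched as follows: sample $m$ sequences of $T$ bits, label each by a sparse $\mathsf{AND}$ (or parity) of the bits indexed by ${\mathcal{I}}$, and report the smallest $m$ at which at least $80\%$ of trials reach $100\%$ validation accuracy. The function names, the uniform i.i.d. bit distribution, and the 0-based indexing are assumptions of this sketch; the paper's exact data distribution and training loop are not reproduced here.

```python
import random


def make_sparse_dataset(T, m, relevant, kind="and", seed=0):
    """Sample m bit strings of length T with sparse Boolean labels.

    Assumption: bits are i.i.d. uniform. `relevant` lists the indices
    the label depends on (0-based in this sketch).
    """
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(m):
        x = [rng.randint(0, 1) for _ in range(T)]
        picked = [x[i] for i in relevant]
        if kind == "and":
            y = int(all(picked))
        else:  # "parity"
            y = sum(picked) % 2
        xs.append(x)
        ys.append(y)
    return xs, ys


def threshold_sample_size(trial_success, min_fraction=0.8):
    """Smallest m whose trials succeed at least `min_fraction` of the time.

    `trial_success` maps a sample size m to a list of booleans, one per
    random trial, marking whether that trial reached 100% validation
    accuracy. Returns None if no sample size qualifies.
    """
    for m in sorted(trial_success):
        outcomes = trial_success[m]
        if outcomes and sum(outcomes) / len(outcomes) >= min_fraction:
            return m
    return None
```

Repeating `threshold_sample_size` over a range of $T$ values, with the per-trial outcomes supplied by an actual training run, would give the kind of curve that Figure 2 describes.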

Theorems & Definitions (61)

  • Definition 2.1: Covering number
  • Lemma 2.2: Generalization bound via covering number; informal
  • Definition 3.1: Attention head
  • Definition 3.2: Transformer layer
  • Theorem 4.2: Attention head capacity
  • Lemma 4.3: $\ell_\infty$-Lipschitzness of $f_\mathsf{head}$
  • Proposition 4.4
  • Corollary 4.5
  • Lemma 4.6
  • Theorem 4.7: Theorem \ref{thm:deeptf_full} (simplified)
  • ...and 51 more
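
To make the covering-number notion behind Definition 2.1 and Lemma 2.2 concrete, the sketch below builds an $\varepsilon$-cover of a finite function class on a fixed sample, measured in the $\ell_\infty$ distance between output vectors. It assumes the standard definition (every function in the class is within $\varepsilon$ of some cover element on every sample point), and the greedy procedure and function names are illustrative choices. The size of the greedy cover only upper-bounds the covering number $\mathcal{N}_\infty$ and need not equal it.

```python
def linf_distance(u, v):
    """Largest coordinate-wise gap between two output vectors."""
    return max(abs(a - b) for a, b in zip(u, v))


def greedy_cover(functions, sample, eps):
    """Greedily pick output vectors that form an eps-cover on `sample`.

    Each function is evaluated on every sample point. A function's
    output vector becomes a new center only if it is farther than eps
    (in l_infinity distance) from all centers chosen so far, so every
    function ends up within eps of some center.
    """
    outputs = [[f(x) for x in sample] for f in functions]
    centers = []
    for out in outputs:
        if all(linf_distance(out, c) > eps for c in centers):
            centers.append(out)
    return centers
```

For instance, applying `greedy_cover` to the linear maps $x \mapsto a x$ with slopes $a \in \{0, 0.1, \dots, 1\}$ on the sample $\{0, 0.5, 1\}$ shows the cover shrinking as `eps` grows, which is the quantity that $\log \mathcal{N}_\infty$ tracks in the lemma.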