Attention layers provably solve single-location regression

Pierre Marion, Raphaël Berthier, Gérard Biau, Claire Boyer

TL;DR

This work studies attention mechanisms theoretically in the presence of token-wise sparsity by formulating single-location regression, a task in which the output depends on a single token whose position in the sequence is latent. It introduces a dedicated predictor that mirrors a simplified, nonlinear self-attention layer, proves its asymptotic Bayes optimality, and analyzes its non-convex training dynamics under projected gradient descent. The results show that the oracle predictor attains Bayes-optimal performance in a high-dimensional regime, whereas linear predictors fail when the relevant location is latent, highlighting the distinct advantages of attention-like architectures. The findings illuminate how Transformers can store and use sparse token information through internal linear representations, with implications for interpretability and for extensions to more complex sparse-sequence tasks in NLP and time-series analysis.

Abstract

Attention-based models, such as Transformers, excel across various tasks but lack a comprehensive theoretical understanding, especially regarding token-wise sparsity and internal linear representations. To address this gap, we introduce the single-location regression task, where only one token in a sequence determines the output, and its position is a latent random variable, retrievable via a linear projection of the input. To solve this task, we propose a dedicated predictor, which turns out to be a simplified version of a non-linear self-attention layer. We study its theoretical properties by showing its asymptotic Bayes optimality and analyzing its training dynamics. In particular, despite the non-convex nature of the problem, the predictor effectively learns the underlying structure. This work highlights the capacity of attention mechanisms to handle sparse token information and internal linear structures.
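The task and predictor described in the abstract can be made concrete with a small simulation. The sketch below is a minimal illustration under assumptions made here, not necessarily the paper's exact model: noise tokens are drawn from $\mathcal{N}(0, I_d/d)$; the informative token at the uniform latent position $J_0$ is additionally shifted by $\gamma k^\star$ and carries $s v^\star$ with $s \sim \mathcal{N}(0,1)$; the target is $Y = X_{J_0}^\top v^\star$; and the attention-like predictor is taken to be $\sum_\ell \mathrm{erf}(\lambda X_\ell^\top k)\, X_\ell^\top v$, with erf as the assumed nonlinearity. The helper names `sample_single_location` and `attention_predictor` are hypothetical.

```python
import math
import numpy as np

# Vectorized error function (numpy has no built-in erf; math.erf is scalar).
erf = np.vectorize(math.erf)


def sample_single_location(n, d, L, gamma, k_star, v_star, rng):
    """Draw n sequences from a hypothetical single-location regression model.

    Every token starts as N(0, I_d / d) noise; in each sequence the token at a
    uniformly random latent position J0 additionally receives gamma * k_star
    (so it is detectable by projecting on k_star) and s * v_star, s ~ N(0, 1).
    The target is Y = <X_{J0}, v_star>.
    """
    X = rng.normal(scale=1.0 / math.sqrt(d), size=(n, L, d))
    J0 = rng.integers(0, L, size=n)              # latent informative position
    s = rng.normal(size=n)
    rows = np.arange(n)
    X[rows, J0] += gamma * k_star + s[:, None] * v_star
    Y = X[rows, J0] @ v_star
    return X, Y


def attention_predictor(X, k, v, lam):
    """Simplified nonlinear attention: sum_l erf(lam <X_l, k>) <X_l, v>."""
    scores = erf(lam * (X @ k))                  # (n, L): soft token selection
    values = X @ v                               # (n, L): per-token readout
    return (scores * values).sum(axis=1)


def unit(x):
    """Normalize a vector onto the unit sphere."""
    return x / np.linalg.norm(x)


rng = np.random.default_rng(0)
d, L, gamma, lam, n = 400, 10, math.sqrt(0.5), 10.0, 1000

# Orthonormal latent directions: k_star locates the token, v_star reads the output.
Q, _ = np.linalg.qr(rng.normal(size=(d, 2)))
k_star, v_star = Q[:, 0], Q[:, 1]

X, Y = sample_single_location(n, d, L, gamma, k_star, v_star, rng)

mse_oracle = np.mean((attention_predictor(X, k_star, v_star, lam) - Y) ** 2)
k_rand, v_rand = unit(rng.normal(size=d)), unit(rng.normal(size=d))
mse_random = np.mean((attention_predictor(X, k_rand, v_rand, lam) - Y) ** 2)
print(f"oracle MSE: {mse_oracle:.4f}, random MSE: {mse_random:.4f}")
```

In this toy model, erf saturates near 1 on the shifted token, while each noise-token term is of order $1/\sqrt{d}$, so the oracle pair $(k^\star, v^\star)$ recovers $Y$ almost exactly and a random pair on the sphere does about as well as predicting 0. The values $d=400$, $L=10$, $\gamma=\sqrt{1/2}$ echo those listed for Figure 3; $\lambda = 10$ is an arbitrary choice.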

Paper Structure (64 sections, 18 theorems, 139 equations, 9 figures, 2 tables)

Key Result

Theorem 1

There exists a function $\mathcal{R}_\lambda^<: \mathbb{R}^5 \to \mathbb{R}$ such that, for any $(k,v) \in (\mathbb{S}^{d-1})^2$, $\mathcal{R}_\lambda(k,v) = \mathcal{R}_\lambda^<(\kappa, \nu, \theta, \eta, \rho)$, where $\kappa := k^\top k^\star$, $\nu := v^\top v^\star$, $\theta := v^\top k^\star$, $\eta := k^\top v^\star$, and $\rho := k^\top v$. A closed-form expression of $\mathcal{R}_\lambda^<$ is given in the Appendix. In particular, …
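Theorem 1 expresses the risk through five scalar overlaps of $(k, v)$ with each other and with $(k^\star, v^\star)$. The helper below computes them; it is a minimal sketch in which the function name `overlaps` is hypothetical and the example assumes, for illustration only, that $k^\star$ and $v^\star$ are orthonormal.

```python
import numpy as np


def overlaps(k, v, k_star, v_star, atol=1e-8):
    """Return (kappa, nu, theta, eta, rho) for unit vectors k and v."""
    for name, x in (("k", k), ("v", v)):
        if not np.isclose(np.linalg.norm(x), 1.0, atol=atol):
            raise ValueError(f"{name} must lie on the unit sphere S^(d-1)")
    kappa = float(k @ k_star)  # k against k_star
    nu = float(v @ v_star)     # v against v_star
    theta = float(v @ k_star)  # cross term: v against k_star
    eta = float(k @ v_star)    # cross term: k against v_star
    rho = float(k @ v)         # overlap between the two trained directions
    return kappa, nu, theta, eta, rho


# Example in d = 6 with orthonormal k_star, v_star (an illustrative assumption).
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(6, 6)))   # random orthogonal matrix
k_star, v_star = Q[:, 0], Q[:, 1]

oracle = overlaps(k_star, v_star, k_star, v_star)  # kappa = nu = 1, others 0 (up to rounding)

# A non-oracle pair built from orthonormal columns, so both vectors have unit norm.
k = 0.6 * Q[:, 0] + 0.8 * Q[:, 2]
v = 0.8 * Q[:, 1] + 0.6 * Q[:, 0]
before = overlaps(k, v, k_star, v_star)          # (0.6, 0.8, 0.6, 0.0, 0.36)

# Rotating all four vectors by the same orthogonal matrix R leaves the overlaps unchanged.
R, _ = np.linalg.qr(rng.normal(size=(6, 6)))
after = overlaps(R @ k, R @ v, R @ k_star, R @ v_star)
print(before, after)
```

Because the overlaps are invariant under a common rotation, they summarize the geometry of $(k, v)$ relative to $(k^\star, v^\star)$ without reference to a particular basis.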

Figures (9)

  • Figure 1: A simple sentiment analysis task with synthetic data, which exemplifies (a) token-wise sparsity and (b) internal linear representations. We refer to the Appendix on experimental details for details on the experiment.
  • Figure 2: Modeling of an NLP task within our statistical setting. The token embeddings $X_1, \dots, X_L$ are constructed by adding the embedding of each word and a positional encoding. For illustration purposes, we assume that each token corresponds to a word, and that the positional encoding depends solely on the part of the sentence (before or after the question mark), which differs from usual practice. Then, let the direction $k^\star$ encode both the notion of sentiment and the position in the second part of the sentence. Thus only the last token of the sentence is aligned (positively) with $k^\star$, and we have $J_0=L$. As for $v^\star$, it encodes whether the word is associated with a positive or negative sentiment. Note that several tokens are positively or negatively aligned with $v^\star$, but the output $Y$ only depends on the token $J_0$. This illustrates the benefit of having two latent directions $k^\star$ and $v^\star$: one that filters the informative token and one that aligns with the output $Y$.
  • Figure 3: Convergence of PGD to the oracle parameters. Left: Excess risk as a function of the number of steps. Middle left: Alignment $|\kappa| = |k^\top k^\star|$ and $|\nu| = |v^\top v^\star|$ with the oracle parameters. Middle right: Trajectories of $\kappa$ and $\nu$ in two repetitions of the experiment. Each repetition corresponds to a color; the trajectory starts in the middle and ends at a corner of the plot. Right: Distance to the invariant manifold ${\mathcal{M}}$. In all plots except the middle right ones, the experiment is repeated $30$ times with independent random initializations, and $95\%$ percentile intervals are plotted (but are not visible when the variance is too small). Parameters are $d=400$, $L=10$, $\gamma = \sqrt{1/2}$, and (a) $\lambda_t = 1/(1 + 10^{-4}t)$, (b) $\lambda_t = 0.1$. More details are given in the Appendix on experimental details.
  • Figure 4: Convergence of online stochastic PGD to the oracle parameters from a random initialization on $(\mathbb{S}^{d-1})^2$. Left: Excess risk as a function of the number of steps. Middle left: Alignment $|\kappa| = |k^\top k^\star|$ and $|\nu| = |v^\top v^\star|$ with the oracle parameters. Middle right: Trajectories of $\kappa$ and $\nu$ in two repetitions of the experiment. Each repetition corresponds to a color; the trajectory starts in the middle and ends at a corner of the plot. Right: Distance to the invariant manifold ${\mathcal{M}}$. In all plots except the middle right one, the experiment is repeated $30$ times with independent random initializations, and $95\%$ percentile intervals are plotted. Parameters are $d=80$, $L=10$, $\gamma = \sqrt{1/2}$, $\lambda_t = 2/(1 + 10^{-4}t)$, and a batch size of $5$. More details are given in the Appendix on experimental details.
  • Figure 5: Dynamics in $(\kappa, \nu)$ on the manifold $\mathcal{M}$. In (a), the fixed points of the dynamics are represented; the minimizers, saddle point, and maximizers are respectively depicted in yellow, blue and red. In (b), the vector field $(\kappa,\nu) \mapsto -(\partial_\kappa \mathcal{R}^< (\kappa,\nu)(1-\kappa^2), \partial_\nu \mathcal{R}^< (\kappa,\nu)(1-\nu^2))$ is displayed (the colormap corresponds to the magnitude of the vector field).
  • ...and 4 more figures
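The factors $(1-\kappa^2)$ and $(1-\nu^2)$ in the Figure 5(b) vector field are what projecting a gradient onto the unit sphere produces: if the risk depends on $k$ only through $\kappa = k^\top k^\star$ with $\|k^\star\| = 1$, its Euclidean gradient is $\partial_\kappa \mathcal{R}\, k^\star$, the tangent projection at $k$ is $(I - kk^\top) k^\star \partial_\kappa \mathcal{R}$, and the induced flow moves $\kappa$ at rate $-k^{\star\top}(I - kk^\top)k^\star\, \partial_\kappa \mathcal{R} = -(1-\kappa^2)\,\partial_\kappa \mathcal{R}$. The sketch below evaluates that field by finite differences for any callable risk. The closed form of $\mathcal{R}^<$ is not reproduced here, so `toy_risk` is a hypothetical stand-in whose minimizers happen to sit at the corners $(\pm 1, \pm 1)$; `sphere_field` is likewise a hypothetical name.

```python
import numpy as np


def sphere_field(risk, kappa, nu, h=1e-5):
    """Projected-gradient field -(dR/dkappa * (1 - kappa^2), dR/dnu * (1 - nu^2)).

    Partial derivatives are approximated by central finite differences, so any
    callable risk(kappa, nu) that accepts scalars or arrays can be plugged in.
    """
    d_kappa = (risk(kappa + h, nu) - risk(kappa - h, nu)) / (2 * h)
    d_nu = (risk(kappa, nu + h) - risk(kappa, nu - h)) / (2 * h)
    return -d_kappa * (1 - kappa**2), -d_nu * (1 - nu**2)


def toy_risk(kappa, nu):
    """Hypothetical stand-in for R^<, NOT the paper's closed form."""
    return 1.0 - (kappa * nu) ** 2


# Evaluate the field on a grid, as one would for a quiver plot.
grid = np.linspace(-1.0, 1.0, 5)
K, N = np.meshgrid(grid, grid)
U, V = sphere_field(toy_risk, K, N)
magnitude = np.hypot(U, V)  # what a colormap of the field's magnitude would show
```

On the boundary $\kappa = \pm 1$ (or $\nu = \pm 1$) the corresponding component vanishes whatever the risk, which is why the edges of the square are invariant under these dynamics.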

Theorems & Definitions (25)

  • Theorem 1
  • Corollary 2
  • Proposition 3
  • Definition 1: PGD
  • Lemma 4
  • Theorem 5
  • Lemma 6
  • Lemma 7
  • Proposition 8
  • Proposition 9
  • ...and 15 more
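The listing above names Definition 1 (PGD) without stating it. The following minimal sketch assumes the standard scheme of a gradient step on each of $k$ and $v$ followed by renormalization of both to the unit sphere, run online with a fresh mini-batch of size 5 at each step, as in the Figure 4 setting. The toy data model, the erf nonlinearity, the fixed $\lambda$, the constant step size, and all function names (`sample`, `loss_and_grads`, `online_pgd`) are illustrative assumptions rather than the paper's exact algorithm or parameters.

```python
import math
import numpy as np

erf = np.vectorize(math.erf)  # elementwise error function


def unit(x):
    """Euclidean projection of a nonzero vector onto the unit sphere."""
    return x / np.linalg.norm(x)


def sample(n, d, L, gamma, k_star, v_star, rng):
    """Hypothetical toy model: N(0, I_d/d) tokens; the token at a uniform latent
    position J0 is shifted by gamma * k_star and carries s * v_star, s ~ N(0, 1)."""
    X = rng.normal(scale=1.0 / math.sqrt(d), size=(n, L, d))
    rows, J0 = np.arange(n), rng.integers(0, L, size=n)
    X[rows, J0] += gamma * k_star + rng.normal(size=(n, 1)) * v_star
    return X, X[rows, J0] @ v_star


def loss_and_grads(X, Y, k, v, lam):
    """Mean squared error of T(X) = sum_l erf(lam <X_l, k>) <X_l, v>, with gradients."""
    a, b = X @ k, X @ v                          # (n, L) key scores and value readouts
    g = erf(lam * a)
    r = (g * b).sum(axis=1) - Y                  # residuals T(X) - Y
    dg = lam * (2.0 / math.sqrt(math.pi)) * np.exp(-(lam * a) ** 2)  # d/da erf(lam a)
    n = len(Y)
    grad_k = 2.0 * np.einsum("n,nl,nld->d", r, dg * b, X) / n
    grad_v = 2.0 * np.einsum("n,nl,nld->d", r, g, X) / n
    return float(np.mean(r**2)), grad_k, grad_v


def online_pgd(k, v, k_star, v_star, *, d, L, gamma, lam, step, n_steps, batch, rng):
    """Online stochastic PGD on (S^{d-1})^2: fresh mini-batch, gradient step, renormalize."""
    k, v = unit(k), unit(v)
    history = []
    for _ in range(n_steps):
        X, Y = sample(batch, d, L, gamma, k_star, v_star, rng)
        _, gk, gv = loss_and_grads(X, Y, k, v, lam)
        k, v = unit(k - step * gk), unit(v - step * gv)  # simultaneous update, then project
        history.append((k @ k_star, v @ v_star))         # track (kappa, nu)
    return k, v, history


rng = np.random.default_rng(1)
d, L = 50, 5
Q, _ = np.linalg.qr(rng.normal(size=(d, 2)))
k_star, v_star = Q[:, 0], Q[:, 1]
k, v, history = online_pgd(
    rng.normal(size=d), rng.normal(size=d), k_star, v_star,
    d=d, L=L, gamma=math.sqrt(0.5), lam=2.0, step=0.05, n_steps=2000, batch=5, rng=rng,
)
kappa, nu = history[-1]
print(f"final |kappa| = {abs(kappa):.3f}, |nu| = {abs(nu):.3f}")
```

Renormalization is the Euclidean projection onto the sphere, so every iterate stays on $(\mathbb{S}^{d-1})^2$; whether and how fast $|\kappa|$ and $|\nu|$ approach 1 depends on the assumed step size, $\lambda$, and data model.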