An Analysis of Attention via the Lens of Exchangeability and Latent Variable Models

Yufeng Zhang, Boyi Liu, Qi Cai, Lingxiao Wang, Zhaoran Wang

TL;DR

The paper develops a principled, exchangeability-based latent-variable theory of attention in encoder-style transformers (e.g., BERT, ViT). It shows that the posterior over a per-data-point latent variable is a sufficient and minimal representation for downstream tasks, and that attention can infer this latent posterior via kernel conditional mean embedding (CME). By linking CME attention with softmax attention in the long-sequence limit, the authors justify the use of attention as a nonparametric conditional density estimator and motivate multi-head architectures. They provide excess-risk analyses covering generalization, approximation, and optimization errors, and show how both supervised and self-supervised objectives can learn the desirable representation with a generalization error that does not depend on sequence length. The self-supervised learning (SSL) analysis introduces transfer- and conditioning-number concepts that quantify how well pretraining supports downstream tasks, offering a principled framework for understanding pretraining transfer in transformer models.

Abstract

With the attention mechanism, transformers achieve significant empirical successes. Despite the intuitive understanding that transformers perform relational inference over long sequences to produce desirable representations, we lack a rigorous theory on how the attention mechanism achieves it. In particular, several intriguing questions remain open: (a) What makes a desirable representation? (b) How does the attention mechanism infer the desirable representation within the forward pass? (c) How does a pretraining procedure learn to infer the desirable representation through the backward pass? We observe that, as is the case in BERT and ViT, input tokens are often exchangeable since they already include positional encodings. The notion of exchangeability induces a latent variable model that is invariant to input sizes, which enables our theoretical analysis.

  • To answer (a) on representation, we establish the existence of a sufficient and minimal representation of input tokens. In particular, such a representation instantiates the posterior distribution of the latent variable given input tokens, which plays a central role in predicting output labels and solving downstream tasks.
  • To answer (b) on inference, we prove that attention with the desired parameter infers the latent posterior up to an approximation error, which is decreasing in input sizes. In detail, we quantify how attention approximates the conditional mean of the value given the key, which characterizes how it performs relational inference over long sequences.
  • To answer (c) on learning, we prove that both supervised and self-supervised objectives allow empirical risk minimization to learn the desired parameter up to a generalization error, which is independent of input sizes. Particularly, in the self-supervised setting, we identify a condition number that is pivotal to solving downstream tasks.
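
To make the abstract's claim in (b) concrete, the following is a minimal sketch, not the paper's construction, of single-query softmax attention read as a kernel smoother that estimates the conditional mean of the value given the key. The synthetic data (unit-norm keys on a circle, scalar values with $\mathbb{E}[v {\,|\,} k] = k_1$), the temperature `tau`, and all function names are assumptions made for illustration. With unit-norm keys, $\exp(q^\top k / \tau)$ is proportional to a Gaussian kernel in $\|q - k\|$, so the softmax weights coincide with Nadaraya-Watson weights.

```python
import numpy as np


def softmax_attention(q, K, V, tau=0.05):
    """Single-query softmax attention with temperature tau (assumed value)."""
    scores = K @ q / tau
    scores -= scores.max()          # stabilise the exponentials
    weights = np.exp(scores)
    weights /= weights.sum()
    return weights @ V


def sample_tokens(rng, L, noise=0.1):
    """Toy data: keys uniform on the unit circle, E[v | k] = k[0]."""
    phi = rng.uniform(0.0, 2.0 * np.pi, size=L)
    K = np.stack([np.cos(phi), np.sin(phi)], axis=1)
    V = K[:, 0] + noise * rng.standard_normal(L)
    return K, V


def mean_abs_error(L, trials=200, seed=0):
    """Average |attention output - E[v | k = q]| over fresh sequences."""
    rng = np.random.default_rng(seed)
    q = np.array([1.0, 0.0])
    errs = []
    for _ in range(trials):
        K, V = sample_tokens(rng, L)
        errs.append(abs(softmax_attention(q, K, V) - q[0]))
    return float(np.mean(errs))


# Error of the attention output as an estimate of E[v | k = q],
# for increasing sequence length L.
errors = {L: mean_abs_error(L) for L in (16, 256, 4096)}
for L, err in errors.items():
    print(f"L={L:5d}  mean abs error={err:.4f}")
```

In this setup the average error at `L = 4096` should be well below the error at `L = 16`; a residual bias remains because `tau` is held fixed rather than shrunk with `L`. The output also does not change when the rows of `K` and `V` are permuted together, which matches the exchangeability viewpoint.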

Paper Structure

This paper contains 12 sections, 7 theorems, 45 equations, and 7 figures.

Key Result

Proposition 3.1

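Let $\{x^\ell\}_{\ell \in \mathbb{N}_+}$ be an exchangeable sequence. Then, there exists a latent variable $z$ such that for any sequence length $L \in \mathbb{N}_+$,

$$\mathbb{P}(x^1, \ldots, x^L) = \int \prod_{\ell=1}^{L} \mathbb{P}(x^\ell {\,|\,} z) \cdot \mathbb{P}(z) \,\mathrm{d}z.$$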
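
For intuition about the latent-variable representation in Proposition 3.1, the sketch below uses the textbook Beta-Bernoulli model. This model is not taken from the paper; the prior parameters `a`, `b` and the function names are assumptions chosen for illustration. Here $z \sim \mathrm{Beta}(a, b)$ and the tokens are i.i.d. Bernoulli($z$) given $z$, so the joint probability of a sequence depends only on how many ones it contains, and the latent posterior is again a Beta distribution.

```python
from itertools import permutations, product
from math import exp, lgamma


def log_beta(a, b):
    """Logarithm of the Beta function B(a, b)."""
    return lgamma(a) + lgamma(b) - lgamma(a + b)


def joint_prob(xs, a=2.0, b=3.0):
    """P(x^1, ..., x^L) after integrating out z ~ Beta(a, b).

    Equals B(a + s, b + L - s) / B(a, b), with s the number of ones.
    """
    s, L = sum(xs), len(xs)
    return exp(log_beta(a + s, b + L - s) - log_beta(a, b))


def latent_posterior(xs, a=2.0, b=3.0):
    """Parameters of the posterior P(z | x^1, ..., x^L) = Beta(a + s, b + L - s)."""
    s, L = sum(xs), len(xs)
    return a + s, b + L - s


xs = (1, 0, 1, 1, 0)

# Exchangeability: every reordering of the sequence has the same probability.
probs = {joint_prob(p) for p in permutations(xs)}
print("distinct joint probabilities over permutations:", len(probs))

# The posterior depends on the tokens only through (s, L).
print("posterior Beta parameters:", latent_posterior(xs))

# Sanity check: probabilities of all length-3 sequences sum to one.
total = sum(joint_prob(seq) for seq in product((0, 1), repeat=3))
print("sum over all length-3 sequences:", total)
```

The posterior parameters summarize everything the sequence says about $z$, which is the sense in which a latent posterior can act as a compact representation of the input tokens in this toy case.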

Figures (7)

  • Figure 1: The input sequence (the raw version without positional encodings) becomes exchangeable with positional encodings. In practice, the positional encoding is incorporated in an additive manner (instead of concatenation).
  • Figure 2: The forward pass for the prediction of the masked token $x^\ell$ and the target variable $y$. The prediction of $y$ takes two steps: i) the inference of the latent posterior $\mathbb{P}(z {\,|\,} X)$, and ii) the prediction of $y$ based on the generative distribution $\mathbb{P}(y{\,|\,} z)$ integrated with the latent posterior $\mathbb{P}(z{\,|\,} X)$.
  • Figure 3: Forward pass: within one data point $(X, y)$, we infer the latent posterior $\mathbb{P}_\theta(z {\,|\,} X)$ by \ref{eq:latent-post-para}. We predict $y_\dagger$ by $\widehat{y}$ in \ref{eq:predict}. Backward pass: across different data points in the dataset $\mathcal{D}_n$, we estimate the learnable parameter $\theta$ by \ref{eq:para}.
  • Figure 4: The forward and backward passes in transformers. Dotted arrows stand for forward passes (input$\rightarrow$latent$\rightarrow$target). Solid arrows stand for backward passes (training). Masks (grey tokens) are only used to illustrate the self-supervised setting (yellow box).
  • Figure 5: As shown in Propositions 4.1 and 4.2, the softmax attention $\mathtt{attn}_{\mathtt{SM}}$ and the CME attention $\mathtt{attn}_{\mathtt{CME}}$ have the same limit $\mathbb{E}[ \mathcal{V} {\,|\,} \mathcal{K} = q]$ as $L\rightarrow \infty$.
  • ...and 2 more figures
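
The shared limit described in the Figure 5 caption can be checked numerically on toy data. The sketch below assumes the standard empirical kernel conditional mean embedding (kernel ridge) estimator, $V^\top (\mathfrak{K}(K, K) + \lambda L I)^{-1} \mathfrak{K}(K, q)$, as a stand-in for $\mathtt{attn}_{\mathtt{CME}}$. The paper's exact definition is not reproduced in this summary, so that form, the Gaussian kernel, the regularizer `lam`, the temperature `tau`, and the synthetic data are all assumptions.

```python
import numpy as np


def gaussian_kernel(A, B, tau):
    """Gaussian kernel matrix with entries exp(-||a - b||^2 / (2 * tau))."""
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * tau))


def cme_attention(q, K, V, tau=0.05, lam=1e-3):
    """Assumed CME-style attention: V^T (G + lam * L * I)^{-1} k(K, q)."""
    L = K.shape[0]
    gram = gaussian_kernel(K, K, tau)
    k_q = gaussian_kernel(K, q[None, :], tau)[:, 0]
    alpha = np.linalg.solve(gram + lam * L * np.eye(L), k_q)
    return alpha @ V


def softmax_attention(q, K, V, tau=0.05):
    """Single-query softmax attention with temperature tau."""
    scores = K @ q / tau
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ V


# Toy data: keys uniform on the unit circle, values with E[v | k] = k[0].
rng = np.random.default_rng(0)
L = 500
phi = rng.uniform(0.0, 2.0 * np.pi, size=L)
K = np.stack([np.cos(phi), np.sin(phi)], axis=1)
V = K[:, 0] + 0.1 * rng.standard_normal(L)

q = np.array([1.0, 0.0])
target = q[0]                      # E[v | k = q] in this toy model
cme_out = cme_attention(q, K, V)
sm_out = softmax_attention(q, K, V)

print(f"target={target:.3f}  CME={cme_out:.3f}  softmax={sm_out:.3f}")
```

Both outputs should land close to the conditional mean at `q`. The CME form pays for an $L \times L$ linear solve, whereas the softmax form only normalizes the kernel weights, which is one way to read the appeal of the softmax variant at long sequence lengths.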

Theorems & Definitions (7)

  • Proposition 3.1: de Finetti Representation Theorem (de Finetti, 1937)
  • Lemma 3.2: Minimal Sufficiency of Latent Posterior
  • Proposition 4.1: CME Attention Converges to Kernel Conditional Mean Embedding
  • Proposition 4.2: Softmax Attention Converges to Kernel Conditional Mean Embedding
  • Theorem 5.3: Generalization Error
  • Theorem 5.5: Approximation Error
  • Proposition 5.6: Optimization Error