An Analysis of Attention via the Lens of Exchangeability and Latent Variable Models
Yufeng Zhang, Boyi Liu, Qi Cai, Lingxiao Wang, Zhaoran Wang
TL;DR
The paper develops a principled, exchangeability-based latent-variable theory of attention in encoder-style transformers (e.g., BERT, ViT). It shows that the posterior over a per-data-point latent variable is a sufficient and minimal representation for downstream tasks, and that attention can perform this latent-posterior inference through a kernel conditional mean embedding (CME) mechanism. By linking CME attention with softmax attention in the large-sequence limit, the authors justify viewing attention as a nonparametric conditional density estimator and motivate multi-head architectures. They provide excess-risk analyses covering generalization, approximation, and optimization errors, and show that both supervised and self-supervised objectives can learn the desired representation with generalization bounds that do not depend on sequence length. The self-supervised learning (SSL) analysis introduces transfer- and condition-number quantities that measure how well pretraining supports downstream tasks, offering a principled framework for understanding pretraining transfer in transformer models.
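To make the idea of attention as a conditional-mean estimator concrete, here is a minimal Python sketch that treats single-query softmax attention as a kernel smoother (a Nadaraya-Watson estimator) of E[V | K = q]. This is a swapped-in illustration, not the paper's CME attention. The scalar keys, the data-generating function `sample_sequence`, and the bandwidth schedule h = 0.5 n^(-1/5) are all hypothetical choices made for the demo. The score -(q - k)^2 / (2h^2) differs from the dot-product score q k / h^2 by a term in q alone, which cancels in the softmax, and by a key-dependent bias -k^2 / (2h^2), which becomes constant when all keys share the same norm.

```python
import math
import random


def softmax_attention(query, keys, values, bandwidth):
    """Single-query softmax attention over scalar keys and values.

    The score -(q - k)^2 / (2 h^2) expands to
    q*k / h^2 - k^2 / (2 h^2) - q^2 / (2 h^2). The q^2 term is shared by
    all keys and cancels in the softmax, so this is dot-product attention
    with temperature h^2 plus a key-dependent bias.
    """
    scores = [-(query - k) ** 2 / (2.0 * bandwidth ** 2) for k in keys]
    top = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    # Convex combination of values: a kernel estimate of E[V | K = query].
    return sum(w * v for w, v in zip(weights, values)) / total


def sample_sequence(n, rng, noise=0.1):
    """Hypothetical tokens: keys ~ U(-1, 1), values = sin(pi * key) + noise."""
    keys = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    values = [math.sin(math.pi * k) + rng.gauss(0.0, noise) for k in keys]
    return keys, values


if __name__ == "__main__":
    rng = random.Random(0)
    q = 0.5  # the true conditional mean is sin(pi / 2) = 1
    for n in (50, 500, 5000, 50000):
        h = 0.5 * n ** -0.2  # shrink the kernel width as the sequence grows
        keys, values = sample_sequence(n, rng)
        est = softmax_attention(q, keys, values, h)
        print(f"n={n:>6}  |error|={abs(est - math.sin(math.pi * q)):.4f}")
```

Shrinking the bandwidth as n grows lets the error relative to E[V | K = q] = sin(pi q) fall with sequence length. This loosely mirrors the claim that the approximation error decreases in input size. With a fixed bandwidth, the error would instead plateau at the smoothing bias.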
Abstract
With the attention mechanism, transformers achieve significant empirical successes. Despite the intuitive understanding that transformers perform relational inference over long sequences to produce desirable representations, we lack a rigorous theory on how the attention mechanism achieves it. In particular, several intriguing questions remain open: (a) What makes a desirable representation? (b) How does the attention mechanism infer the desirable representation within the forward pass? (c) How does a pretraining procedure learn to infer the desirable representation through the backward pass?
We observe that, as is the case in BERT and ViT, input tokens are often exchangeable since they already include positional encodings. The notion of exchangeability induces a latent variable model that is invariant to input sizes, which enables our theoretical analysis.
- To answer (a) on representation, we establish the existence of a sufficient and minimal representation of input tokens. In particular, such a representation instantiates the posterior distribution of the latent variable given input tokens, which plays a central role in predicting output labels and solving downstream tasks.
- To answer (b) on inference, we prove that attention with the desired parameter infers the latent posterior up to an approximation error, which is decreasing in input sizes. In detail, we quantify how attention approximates the conditional mean of the value given the key, which characterizes how it performs relational inference over long sequences.
- To answer (c) on learning, we prove that both supervised and self-supervised objectives allow empirical risk minimization to learn the desired parameter up to a generalization error, which is independent of input sizes. Particularly, in the self-supervised setting, we identify a condition number that is pivotal to solving downstream tasks.
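The representation claim in (a) can be illustrated with a toy exchangeable model. The Python sketch below assumes a finite latent variable z with a prior, and tokens drawn i.i.d. from N(means[z], sigma^2) given z. The model and the function `latent_posterior` are hypothetical and are only meant to show the properties the abstract relies on. Here the posterior p(z | x_1, ..., x_n) is computed exactly, whereas the paper's point is that attention approximates such a posterior. The posterior is invariant to token order and is defined for any sequence length n. For this Gaussian model with a shared sigma, it depends on the tokens only through n and their sum.

```python
import math
import random


def latent_posterior(tokens, means, prior, sigma=1.0):
    """Exact posterior p(z | x_1..x_n) in a hypothetical toy model:
    z ~ prior, then x_i | z ~ N(means[z], sigma^2) i.i.d.

    Tokens enter only through a sum of per-token log-likelihoods, so the
    result does not depend on token order and is defined for any n.
    """
    log_post = []
    for mu, p in zip(means, prior):
        # Gaussian normalizing constants are shared by every z and dropped.
        loglik = sum(-(x - mu) ** 2 / (2.0 * sigma ** 2) for x in tokens)
        log_post.append(math.log(p) + loglik)
    top = max(log_post)  # log-sum-exp trick for numerical stability
    unnorm = [math.exp(lp - top) for lp in log_post]
    total = sum(unnorm)
    return [u / total for u in unnorm]


if __name__ == "__main__":
    rng = random.Random(0)
    means, prior = [-1.0, 0.0, 1.0], [1 / 3, 1 / 3, 1 / 3]
    true_z = 2  # latent value used to generate the tokens
    for n in (1, 10, 100):
        tokens = [rng.gauss(means[true_z], 1.0) for _ in range(n)]
        post = latent_posterior(tokens, means, prior)
        print(n, [round(p, 3) for p in post])
```

As n grows, the posterior mass on the generating latent value should approach one. This is the sense in which a fixed-size posterior can summarize an arbitrarily long exchangeable sequence.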
