
Mapping of attention mechanisms to a generalized Potts model

Riccardo Rende, Federica Gerace, Alessandro Laio, Sebastian Goldt

TL;DR

It is shown analytically that, if one decouples the treatment of word positions and embeddings, a single layer of self-attention learns the conditionals of a generalized Potts model with interactions between sites and Potts colors.

Abstract

Transformers are neural networks that have revolutionized natural language processing and machine learning. They process sequences of inputs, like words, using a mechanism called self-attention, which is trained via masked language modeling (MLM). In MLM, a word is randomly masked in an input sequence, and the network is trained to predict the missing word. Despite the practical success of transformers, it remains unclear what type of data distribution self-attention can learn efficiently. Here, we show analytically that if one decouples the treatment of word positions and embeddings, a single layer of self-attention learns the conditionals of a generalized Potts model with interactions between sites and Potts colors. Moreover, we show that training this neural network is exactly equivalent to solving the inverse Potts problem by the so-called pseudo-likelihood method, well known in statistical physics. Using this mapping, we analytically compute the generalization error of self-attention in a model scenario using the replica method.
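The link between MLM and the pseudo-likelihood method can be illustrated with a short NumPy sketch. It assumes (this form is not quoted from the paper) a generalized Potts energy $H(s) = -\tfrac{1}{2}\sum_{i \neq j} J_{ij}\, \mathbf{x}_i^\top U \mathbf{x}_j$, where $\mathbf{x}_i$ is the one-hot vector of the color at site $i$ and $J$, $U$ are symmetric, so that the conditional of site $k$ is a softmax over the field $\sum_{j \neq k} J_{kj} (U\mathbf{x}_j)_a$. Masking each site in turn and averaging the cross-entropy of the masked token then gives the negative log pseudo-likelihood. The function names are hypothetical.

```python
import numpy as np

def log_softmax(z):
    # numerically stable log-softmax over the C colors
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def conditional_log_probs(s, J, U, k):
    """log p(s_k = a | s_{-k}) for every color a, under the assumed Potts energy."""
    others = np.arange(len(s)) != k          # every site except the masked one
    # column j of U[:, s[others]] is U x_j; weight it by the coupling J[k, j]
    field = U[:, s[others]] @ J[k, others]
    return log_softmax(field)

def neg_log_pseudo_likelihood(S, J, U):
    """Average MLM cross-entropy when each site of each sequence is masked once.

    S: integer array of shape (M, L) holding the color of every site.
    """
    M, L = S.shape
    total = 0.0
    for s in S:
        for k in range(L):
            # cross-entropy of predicting the true color s[k] at the masked site k
            total -= conditional_log_probs(s, J, U, k)[s[k]]
    return total / (M * L)
```

With $J = 0$ every conditional is uniform and the loss equals $\log C$. Under the stated assumptions, minimising this quantity over $J$ and $U$ is the pseudo-likelihood approach to the inverse Potts problem.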

Paper Structure (10 sections, 43 equations, 3 figures)

Figures (3)

  • Figure 1: Masked language modeling (MLM) with a single layer of self-attention. The goal of MLM is to predict the masked word in a given sentence. Self-attention first maps words into representations $\mathbf{e}_j + \mathbf{p}_j$, where the embedding vectors $\mathbf{e}_j$ represent words and the positional vectors $\mathbf{p}_j$ encode their positions. For a given masked word, the associated attention vector $\mathbf{h}_k$ is computed as a linear combination of the values $\mathbf{v}_j = V(\mathbf{e}_j + \mathbf{p}_j)$ of all other tokens, weighted by the attention weights $A_{kj}$. In vanilla self-attention, both the values and the attention weights depend on embeddings and positional vectors, while in factored attention the attention weights depend only on positions and the values only on the embeddings. By identifying the attention weights $A$ with the interaction matrix $J$ of the Potts model \eqref{eq:potts}, the value matrix $V$ with the color similarity matrix $U$, and the embedding vectors with the one-hot spins, we obtain a learning model identical to a Potts model.
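The identification in Figure 1 can be made concrete with a minimal NumPy sketch of the factored-attention output for a masked position $k$. Assumptions not specified in the caption: the attention weights are a free $L \times L$ matrix `A` with no softmax, the embeddings are one-hot vectors, the masked token's own weight is set to zero, and $\mathbf{h}_k$ is read out directly as logits over the $C$ colors. `factored_attention_logits` and `potts_field` are hypothetical helper names.

```python
import numpy as np

def factored_attention_logits(s, A, V, k, C):
    """h_k = sum_{j != k} A[k, j] * V e_j for the masked position k.

    In factored attention the weights A[k, j] depend only on positions,
    and the values v_j = V e_j depend only on the one-hot embeddings e_j.
    """
    E = np.eye(C)[s]          # (L, C): row j is the one-hot embedding e_j
    values = E @ V.T          # (L, C): row j is v_j = V e_j
    weights = A[k].copy()
    weights[k] = 0.0          # the masked token does not attend to itself
    return weights @ values   # (C,) logits over the colors of token k

def potts_field(s, J, U, k):
    """Local field sum_{j != k} J[k, j] * U x_j of an assumed generalized Potts model."""
    others = np.arange(len(s)) != k
    return U[:, s[others]] @ J[k, others]
```

Setting `A = J` and `V = U` makes the two functions return the same vector for every $k$, which is the mapping stated in the caption. In vanilla attention, `A` and `V` would instead act on $\mathbf{e}_j + \mathbf{p}_j$, mixing positions and colors.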
  • Figure 2: A single layer of factored self-attention learns the generalized Potts model efficiently. (a) Test loss \eqref{eq:eg} for factored self-attention and for vanilla transformers with one and three layers during training with stochastic gradient descent. The optimal generalization loss is shown as a black dashed line. (b) Interaction matrix $J$ of the generative Potts model \eqref{eq:potts} compared with the attention maps learned by transformers with vanilla and factored self-attention. For the three-layer transformer, the attention map was obtained by averaging the maps of the last two layers. (c) Reconstruction error of the interactions, $(J - A)^2$, as a function of the number of epochs for all considered architectures. (d) Test loss as a function of the perturbation level $a$. Decreasing $a$, which decouples the treatment of positions and colors, decreases the test loss. Parameters: sequence length $L=20$, vocabulary size $C=20$, embedding dimension $d=20$, $M=3000$ data points.
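Panels (b) and (c) of Figure 2 compare the true couplings $J$ with a learned attention map $A$. The NumPy sketch below shows one way such a comparison could be computed. The caption does not say how $(J - A)^2$ is aggregated, or whether the diagonal and any rescaling of $A$ are treated specially, so the mean over off-diagonal entries used here is an assumption, as are the helper names.

```python
import numpy as np

def average_attention_maps(maps, last=2):
    """Average the L x L attention maps of the last `last` layers."""
    return np.mean(np.stack(maps[-last:]), axis=0)

def reconstruction_error(J, A, ignore_diagonal=True):
    """Mean of (J - A)^2, by default over off-diagonal entries only (an assumption)."""
    diff = (J - A) ** 2
    if ignore_diagonal:
        off = ~np.eye(J.shape[0], dtype=bool)   # mask selecting i != j
        return float(diff[off].mean())
    return float(diff.mean())
```

For the three-layer transformer, `average_attention_maps(maps, last=2)` mirrors the caption's averaging of the last two layers' maps before comparing with $J$.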
  • Figure 3: The interpolation peak of factored attention in theory and practice. (Left) A replica analysis predicts the test loss exactly. Test loss of a single layer of factored self-attention as a function of the number of samples per input dimension, computed using replica theory (solid line). The blue points show the outcome of numerical minimisation of the square loss \eqref{eq:simplified_optimization}, averaged over 30 realisations, in perfect agreement with the theory. Error bars are smaller than the point size. (Right) Same plot for a single layer of factored self-attention in the setting of Figure 2 ($L=C=20$), showing the same qualitative behaviour. The simulations are averaged over $n=30$ realisations.