Dissecting the Interplay of Attention Paths in a Statistical Mechanics Theory of Transformers

Lorenzo Tiberi, Francesca Mignacco, Kazuki Irie, Haim Sompolinsky

TL;DR

This paper develops a finite-width thermodynamic Bayesian theory for a Transformer-like multi-head self-attention model, deriving exact predictor statistics as a weighted sum of path-path kernels. A key finding is the task-relevant kernel combination: in the finite-width regime, cross-path kernels are learned and weighted to align with task labels, enhancing generalization beyond the infinite-width Gaussian-process limit. The authors validate the theory with Hamiltonian Monte Carlo on synthetic (Hidden Markov) and real (one-shot image) tasks and show how the order parameter reveals which attention paths to emphasize or suppress, enabling effective head pruning with minimal performance loss. The approach provides interpretable connections between learned weights and kernel alignment, offering guidance for resource-efficient transformer implementations and potential extensions to gradient-descent trained networks.

Abstract

Despite the remarkable empirical performance of Transformers, their theoretical understanding remains elusive. Here, we consider a deep multi-head self-attention network that is closely related to Transformers yet analytically tractable. We develop a statistical mechanics theory of Bayesian learning in this model, deriving exact equations for the network's predictor statistics under the finite-width thermodynamic limit, i.e., $N,P\rightarrow\infty$, $P/N=\mathcal{O}(1)$, where $N$ is the network width and $P$ is the number of training examples. Our theory shows that the predictor statistics are expressed as a sum of independent kernels, each one pairing different 'attention paths', defined as information pathways through different attention heads across layers. The kernels are weighted according to a 'task-relevant kernel combination' mechanism that aligns the total kernel with the task labels. As a consequence, this interplay between attention paths enhances generalization performance. Experiments confirm our findings on both synthetic and real-world sequence classification tasks. Finally, our theory explicitly relates the kernel combination mechanism to properties of the learned weights, allowing for a qualitative transfer of its insights to models trained via gradient descent. As an illustration, we demonstrate an efficient size reduction of the network by pruning those attention heads that are deemed less relevant by our theory.
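
The kernel combination described above can be illustrated with a small numerical sketch. It is not the authors' implementation: the per-path features, the array shapes, and the names `combine_path_kernels`, `U_gp`, and `U_fw` are assumptions made for illustration only. The sketch forms a total kernel as a weighted sum of path-path kernels, with weights given by a matrix $U$ built as an overlap of hypothetical effective weights, and compares it with an identity-weighted combination standing in for the GP limit, where cross-path kernels are discarded and same-path kernels are equally weighted.

```python
import numpy as np


def combine_path_kernels(path_kernels, U):
    """Total kernel K = sum over path pairs (a, b) of U[a, b] * K_ab.

    path_kernels has shape (n_paths, n_paths, P, P); U has shape (n_paths, n_paths).
    """
    return np.einsum("ab,abij->ij", U, path_kernels)


rng = np.random.default_rng(0)
n_paths, P, D = 3, 5, 8  # illustrative sizes, not taken from the paper

# Hypothetical per-path features for P examples (an assumption of this sketch).
features = rng.standard_normal((n_paths, P, D))

# Path-path kernels: K_ab[i, j] = <features_a[i], features_b[j]> / D.
path_kernels = np.einsum("aid,bjd->abij", features, features) / D

# Stand-in for the GP limit: only same-path kernels, all weighted equally.
U_gp = np.eye(n_paths)
K_gp = combine_path_kernels(path_kernels, U_gp)

# Stand-in for a finite-width order parameter: overlaps of hypothetical
# effective weights V, which gives nonzero cross-path weights.
V = rng.standard_normal((n_paths, 2))
U_fw = V @ V.T
K_fw = combine_path_kernels(path_kernels, U_fw)

print("GP-like kernel shape:", K_gp.shape)
print("Largest cross-path weight:", np.abs(U_fw - np.diag(np.diag(U_fw))).max())
```

In this toy setting the identity-weighted kernel reduces to the plain sum of same-path kernels, while the overlap-based $U$ also mixes in cross-path kernels and still yields a symmetric positive semi-definite total kernel.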

Paper Structure (51 sections, 113 equations, 7 figures)

Figures (7)

  • Figure 1: Scheme of the model and theory. (a) Scheme of the model in terms of attention paths. (b) The order parameter assigns to each pair of paths a weight, given by the overlap between the corresponding effective weights. (c) Alignment of the kernel PCs with the vector of task labels $Y$, in the finite-width (FW) vs GP regimes. (d) Kernel as the weighted sum of many path-path kernels. Task-relevant kernel combination occurs in the finite-width regime (FW), but not in the GP limit, in which cross-path kernels are discarded, and same-path kernels are equally weighted. The result is an improved kernel-task alignment in the finite-width regime (shown in (c)), enhancing generalization.
  • Figure 2: Hidden Markov chain task. (a) Illustration of the task. (b) Schematics of the network and its attention paths. (c) Top: Classification accuracy for varying $N$ (theory: blue crosses, joined by blue line; samples: black dots). Red lines: GP limit for a network consisting of all paths (solid), the good path (dashed), and the good and denoising paths (dotted). Bottom: Matrix elements of $U$, for varying $N$. The matrix indices are labeled with the corresponding path name, according to the legend in (b). (d) Normalized overlap, or cosine similarity, between the PCs of the kernel $K$ and the vector of task labels $Y$ ($N=10$: blue; GP limit: orange). PCs are ranked by their eigenvalues, from largest to smallest. Only the first $30$ PCs are shown. (e) Same as (c), but for increased $\sigma_{\perp}=5$ and a network consisting of only the good and denoising paths.
  • Figure 3: One-shot image classification task. (a) Scheme of the task. (b) Classification accuracy in the GP limit (red line) and the finite-width regime (FW) for varying $N$ (theory: blue crosses, joined by blue line; samples: black dots). (c) Matrix elements of $U$. The "theory" and "sampled" $U$s are for $N=10$. The matrix indices are labeled with the path index $\pi=(h_{1},h_{2})$. (d) Kernel PCs' overlap with the task, in the GP limit and in the finite-width regime for $N=10$. Only the first $50$ PCs are shown. (e) Head score (blue) and performance drop (red) after pruning the head, for the model trained with gradient descent. (f) Classification accuracy of the model trained with gradient descent, after pruning a growing number of heads, in order of their head score.
  • Figure 4: Schematic representation of the architecture under consideration.
  • Figure 5: Hidden Markov chain task. (a) Schematics of the network and its attention paths. (b) Kernels. Same-path kernels associated with the 4 paths shown in (a), total kernel in the GP limit, and total kernel for $N=10$ in the renormalized regime (RN). Examples on both the $x$ and $y$ axes are ordered by class (the first half correspond to the first class, the second half correspond to the second class). (c) Classification accuracy for a network consisting of only the good and denoising paths, $\sigma_{\parallel}=1$ and $\sigma_{\perp}=5$. The figure is analogous to Fig. \ref{fig:markov}(f), with the difference that we have replaced the random head involved in the denoising path with a uniform attention head.
  • ...and 2 more figures
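
The kernel-task alignment shown in Figures 1(c), 2(d), and 3(d) is described as the normalized overlap (cosine similarity) between the kernel's principal components and the label vector $Y$, with PCs ranked by eigenvalue from largest to smallest. A minimal sketch of that computation follows. The function name `pc_label_overlap`, the toy kernel, and the use of the absolute value (to remove the arbitrary sign of each eigenvector) are assumptions of this sketch, not details confirmed by the paper.

```python
import numpy as np


def pc_label_overlap(K, Y):
    """Return eigenvalues (descending) and |cosine similarity| of each PC with Y."""
    eigvals, eigvecs = np.linalg.eigh(K)  # eigh returns eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1]     # re-rank from largest to smallest
    eigvals = eigvals[order]
    eigvecs = eigvecs[:, order]
    # Eigenvectors have unit norm, so only Y needs normalizing.
    overlaps = np.abs(eigvecs.T @ Y) / np.linalg.norm(Y)
    return eigvals, overlaps


# Toy example: a kernel whose leading PC is the label direction itself.
Y = np.array([1.0, 1.0, 1.0, -1.0, -1.0, -1.0])
K_aligned = np.outer(Y, Y) + 0.1 * np.eye(len(Y))

eigvals, overlaps = pc_label_overlap(K_aligned, Y)
print("Leading eigenvalue:", eigvals[0])
print("Overlap of leading PC with labels:", overlaps[0])
```

Because the eigenvectors form an orthonormal basis, the squared overlaps sum to one, so a well-aligned kernel concentrates that mass in its leading PCs.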