Dissecting the Interplay of Attention Paths in a Statistical Mechanics Theory of Transformers
Lorenzo Tiberi, Francesca Mignacco, Kazuki Irie, Haim Sompolinsky
TL;DR
This paper develops a Bayesian theory of a Transformer-like multi-head self-attention model in the finite-width thermodynamic limit, deriving exact equations for the predictor statistics in the form of a weighted sum of kernels that pair attention paths. The key finding is a task-relevant kernel combination mechanism: in the finite-width regime, the path-path kernels, including cross-path ones, are weighted so that the total kernel aligns with the task labels, which improves generalization beyond the infinite-width Gaussian-process limit. The authors validate the theory with Hamiltonian Monte Carlo on synthetic (hidden Markov) and real (one-shot image) tasks. They also show how the order parameter reveals which attention paths to emphasize or suppress, which enables head pruning with minimal performance loss. The approach gives an interpretable connection between learned weights and kernel alignment, offers guidance for resource-efficient Transformer implementations, and may extend to networks trained with gradient descent.
Abstract
Despite the remarkable empirical performance of Transformers, their theoretical understanding remains elusive. Here, we consider a deep multi-head self-attention network that is closely related to Transformers yet analytically tractable. We develop a statistical mechanics theory of Bayesian learning in this model, deriving exact equations for the network's predictor statistics in the finite-width thermodynamic limit, i.e., $N,P\rightarrow\infty$, $P/N=\mathcal{O}(1)$, where $N$ is the network width and $P$ is the number of training examples. Our theory shows that the predictor statistics are expressed as a sum of independent kernels, each one pairing different 'attention paths', defined as information pathways through different attention heads across layers. The kernels are weighted according to a 'task-relevant kernel combination' mechanism that aligns the total kernel with the task labels. As a consequence, this interplay between attention paths enhances generalization performance. Experiments confirm our findings on both synthetic and real-world sequence classification tasks. Finally, our theory explicitly relates the kernel combination mechanism to properties of the learned weights, allowing for a qualitative transfer of its insights to models trained via gradient descent. As an illustration, we demonstrate an efficient size reduction of the network by pruning those attention heads that our theory deems less relevant.
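The kernel combination idea described in the abstract can be illustrated with a small toy sketch. This is not the paper's derivation or its equations: the per-path features `feats`, the weight matrix `U`, and the cosine-style alignment score are all assumptions made here purely for illustration. The sketch builds one kernel per pair of paths, combines them with a weight matrix, and compares how well the combined kernel aligns with the labels when all same-path kernels are weighted equally versus when only the label-carrying path is kept.

```python
import numpy as np

def combine_path_kernels(path_kernels, U):
    """Weighted sum K = sum_{a,b} U[a, b] * path_kernels[a, b].

    path_kernels has shape (n_paths, n_paths, P, P); U has shape (n_paths, n_paths).
    """
    return np.einsum("ab,abij->ij", U, path_kernels)

def kernel_target_alignment(K, y):
    """Cosine similarity between K and the label outer product y y^T (illustrative score)."""
    yy = np.outer(y, y)
    return np.sum(K * yy) / (np.linalg.norm(K) * np.linalg.norm(yy))

rng = np.random.default_rng(0)
P, n_paths, d = 20, 3, 5          # toy sizes, chosen arbitrarily

# Binary labels in {-1, +1}.
y = np.sign(rng.standard_normal(P))

# Hypothetical per-path features: path 0 carries label signal, the others are pure noise.
feats = rng.standard_normal((n_paths, P, d))
feats[0, :, 0] = 3.0 * y

# Path-path kernels: path_kernels[a, b] = feats[a] @ feats[b].T
path_kernels = np.einsum("aid,bjd->abij", feats, feats)

# Equal weight on every same-path kernel, no cross-path terms.
U_uniform = np.eye(n_paths)
# Hand-picked weights that keep only the informative path (assumed, not learned here).
U_task = np.diag([1.0, 0.0, 0.0])

K_uniform = combine_path_kernels(path_kernels, U_uniform)
K_task = combine_path_kernels(path_kernels, U_task)

align_uniform = kernel_target_alignment(K_uniform, y)
align_task = kernel_target_alignment(K_task, y)
print(f"alignment, uniform weights: {align_uniform:.3f}")
print(f"alignment, task weights:    {align_task:.3f}")
```

In this toy setting the weights in `U_task` are set by hand. In the theory summarized above, the weighting emerges from Bayesian learning at finite width, and paths that receive little weight are the natural candidates for head pruning.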
