Quantitative Clustering in Mean-Field Transformer Models
Shi Chen, Zhengjiang Lin, Yury Polyanskiy, Philippe Rigollet
TL;DR
This work models token evolution in deep transformers as mean-field attention dynamics on the sphere and analyzes these dynamics as a Wasserstein gradient flow of an interaction energy ${\mathsf{E}}_{\beta}$. It establishes that non-Dirac stationary points are unstable and that, for small $\beta$ and suitable initial data (e.g., densities in $L^2$), the flow converges exponentially fast to a Dirac mass, with a Polyak–Łojasiewicz structure underpinning the rate. The authors extend the analysis to general transformer-attention mechanisms with learned matrices, deriving first- and second-variation formulas and showing that, under spectral assumptions, local maxima are global maxima attained at point masses. They prove a main convergence theorem with explicit rates and constants, and give a detailed treatment of the perturbative (Kuramoto-like) regime with refined constant estimates. These results clarify the long-time clustering (synchronization) behavior of mean-field transformer models and establish rigorous convergence mechanisms distinct from finite-particle intuitions, with implications for interpreting learned representations in deep architectures.
Abstract
The evolution of tokens through a deep transformer model can be modeled as an interacting particle system that has been shown to exhibit asymptotic clustering behavior akin to the synchronization phenomenon in Kuramoto models. In this work, we investigate the long-time clustering of mean-field transformer models. More precisely, under some assumptions on the parameters of the transformer model, we establish quantitative exponential rates of contraction to a Dirac point mass: any suitably regular mean-field initialization synchronizes exponentially fast.
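
To make the clustering phenomenon concrete, the following is a minimal finite-particle sketch, not the authors' construction. It assumes the unnormalized attention dynamics on the unit sphere, $\dot{x}_i = P_{x_i}\big(\tfrac{1}{n}\sum_j e^{\beta \langle x_i, x_j\rangle} x_j\big)$, with learned matrices taken to be the identity, and it assumes the normalization $\tfrac{1}{2\beta}\iint e^{\beta\langle x,y\rangle}\,d\mu\,d\mu$ for the interaction energy. The function name, the explicit Euler step, and all parameter values are illustrative choices.

```python
import numpy as np


def simulate_attention_dynamics(n=64, d=3, beta=1.0, dt=0.05, steps=2000, seed=0):
    """Euler scheme for unnormalized attention dynamics on the unit sphere.

    Assumed model (illustrative): dx_i/dt = P_{x_i}( (1/n) sum_j exp(beta <x_i, x_j>) x_j ),
    where P_x projects onto the tangent space of the sphere at x.
    Returns the final particle positions and the energy recorded before each step.
    """
    rng = np.random.default_rng(seed)
    # Initialize n tokens uniformly at random on the sphere S^{d-1}.
    x = rng.normal(size=(n, d))
    x /= np.linalg.norm(x, axis=1, keepdims=True)

    energies = np.empty(steps)
    for k in range(steps):
        gram = x @ x.T                    # pairwise inner products <x_i, x_j>
        weights = np.exp(beta * gram)     # unnormalized attention weights
        # Empirical interaction energy, assumed normalization 1/(2*beta).
        energies[k] = weights.mean() / (2.0 * beta)

        v = weights @ x / n               # attention-weighted average of tokens
        # Remove the radial component so the velocity is tangent to the sphere.
        v -= np.sum(v * x, axis=1, keepdims=True) * x

        x = x + dt * v                    # explicit Euler step
        # Retract back onto the sphere to control discretization drift.
        x /= np.linalg.norm(x, axis=1, keepdims=True)
    return x, energies


if __name__ == "__main__":
    x, energies = simulate_attention_dynamics()
    print("smallest pairwise inner product:", (x @ x.T).min())
    print("initial energy:", energies[0], "final energy:", energies[-1])
```

Under these assumptions, the smallest pairwise inner product approaching 1 indicates that the tokens have collapsed to a single point, the finite-particle analogue of convergence to a Dirac mass. With this sign convention the energy increases along the flow toward its value at a point mass, $e^{\beta}/(2\beta)$.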
