Quantitative Clustering in Mean-Field Transformer Models

Shi Chen, Zhengjiang Lin, Yury Polyanskiy, Philippe Rigollet

TL;DR

This work models token evolution in deep transformers as mean-field attention dynamics on the sphere and analyzes it as a Wasserstein gradient flow of an interaction energy ${\mathsf{E}}_{\beta}$. It establishes that non-Dirac stationary points are unstable and that, under small $\beta$ and suitable initial data (e.g., densities in $L^2$), the flow converges exponentially fast to a Dirac mass, with a Polyak–Łojasiewicz structure underpinning the rate. The authors extend the analysis to general transformer-attention mechanisms with learned matrices, deriving first/second variation formulas and showing that, under spectral assumptions, local maxima are global maxima at point masses; they prove a main convergence theorem with explicit rates and constants, and provide a detailed treatment of the perturbative (Kuramoto-like) regime and refined constant estimates. These results illuminate the long-time clustering behavior (synchronization) in mean-field transformer models and establish rigorous convergence mechanisms distinct from finite-particle intuitions, with implications for the interpretation of learned representations in deep architectures.

Abstract

The evolution of tokens through a deep transformer model can be modeled as an interacting particle system that has been shown to exhibit an asymptotic clustering behavior akin to the synchronization phenomenon in Kuramoto models. In this work, we investigate the long-time clustering of mean-field transformer models. More precisely, under some assumptions on the parameters of the transformer model, we establish exponential rates of contraction to a Dirac point mass: any suitably regular mean-field initialization synchronizes exponentially fast, with quantitative rates.

Paper Structure

This paper contains 15 sections, 32 theorems, 267 equations, 4 figures.

Key Result

Proposition 2.1

Let $d \geq 3$. For any $\beta > 0$, any local maximum of the interaction energy $\mathsf{E}_{\beta}$ is a global maximum of the form $\mu=\delta_{x_0}$ for some $x_0\in\mathbb{S}^{d-1}$.
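To make the statement concrete, here is a minimal numerical sketch of a finite-particle analogue. This page does not define $\mathsf{E}_{\beta}$, so the code assumes the commonly used form $\mathsf{E}_{\beta}[\mu] = \frac{1}{2\beta}\iint e^{\beta\langle x, y\rangle}\,d\mu(x)\,d\mu(y)$ and the unnormalized attention velocity field $P_x\big(\int e^{\beta\langle x, y\rangle} y\,d\mu(y)\big)$; both are assumptions, as are the function names, step size, and particle count. Under that assumption a Dirac mass has energy $e^{\beta}/(2\beta)$, and in this toy run a random initialization approaches that value. It is an illustration only, not the paper's proof technique, and it does not contradict Example 2.5, which concerns specific finite configurations.

```python
import numpy as np

def interaction_energy(X, beta):
    """Empirical energy (1 / (2*beta*n^2)) * sum_{i,j} exp(beta * <x_i, x_j>).

    ASSUMPTION: this form of the energy is not given on this page.
    """
    n = X.shape[0]
    return np.exp(beta * (X @ X.T)).sum() / (2.0 * beta * n * n)

def step(X, beta, dt):
    """One projected Euler step of the assumed particle dynamics on the sphere."""
    n = X.shape[0]
    W = np.exp(beta * (X @ X.T))                      # pairwise attention weights
    V = (W @ X) / n                                   # velocity in ambient space
    V -= np.sum(V * X, axis=1, keepdims=True) * X     # project onto tangent space at x_i
    X_new = X + dt * V
    return X_new / np.linalg.norm(X_new, axis=1, keepdims=True)  # back onto the sphere

def simulate(n=64, d=3, beta=1.0, dt=0.05, steps=1000, seed=0):
    """Run the flow from i.i.d. uniform points on S^{d-1}; all defaults are illustrative."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    energies = [interaction_energy(X, beta)]
    for _ in range(steps):
        X = step(X, beta, dt)
        energies.append(interaction_energy(X, beta))
    return X, np.array(energies)

if __name__ == "__main__":
    beta = 1.0
    X, energies = simulate(beta=beta)
    print("initial energy:", energies[0])
    print("final energy:  ", energies[-1])
    print("Dirac energy:  ", np.exp(beta) / (2.0 * beta))
    print("min pairwise inner product:", (X @ X.T).min())
```

A minimum pairwise inner product close to 1 indicates that all particles have collapsed to a single point, the discrete counterpart of $\mu = \delta_{x_0}$.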

Figures (4)

  • Figure 1: Illustration of $\mu_0$ in Example 2.5 with $\xi=0.7$. Circle radii are proportional to mass at each point. The cross indicates the mean of $\mu_0$ with $R_0>0.7$ and the arrows indicate velocity fields for initial angles.
  • Figure 2: Illustration of $f_0$ in Example 2.6. The cross indicates the mean of $f_0$ with $R_0>0$, and the arrows indicate velocity fields at the boundaries of the support of $f_0$.
  • Figure 3: Illustration of gnomonic projection.
  • Figure 4: Examples of non-monotonic evolution of $R_t$. The top row shows $R_t^2$ and the bottom row shows $\partial_t(R_t^2)$; each column corresponds to one of two initial profiles. These plots illustrate that $R_t^2$ can exhibit non-monotonic behavior, with $\partial_t(R_t^2)$ taking both positive and negative values over time. The plots use $\phi(\langle Ax, y\rangle) = e^{0.1\langle x, y\rangle}$.

Theorems & Definitions (64)

  • Proposition 2.1
  • Example 2.2: No Łojasiewicz inequality for $\mathsf{E}_{\beta}$
  • Theorem 2.3: Polyak-Łojasiewicz inequality on a spherical cap
  • Theorem 2.4
  • Example 2.5: No synchronization for finite particles
  • Example 2.6: No mean-field synchronization for large $\beta$
  • Theorem 3.1
  • Lemma 3.2: First Variation Formula for $\mathsf{E}_{\phi}$
  • proof
  • Lemma 3.3: Second Variation Formula for $\mathsf{E}_{\phi}[\cdot]$
  • ...and 54 more