Emergence of meta-stable clustering in mean-field transformer models

Giuseppe Bruno, Federico Pasqualotto, Andrea Agazzi

TL;DR

We perform a perturbative analysis of the mean-field PDE around the i.i.d. uniform initialization and prove that, in the limit of a large number of tokens, the model remains close to a meta-stable manifold of solutions with a given structure.

Abstract

We model the evolution of tokens within a deep stack of Transformer layers as a continuous-time flow on the unit sphere, governed by a mean-field interacting particle system, building on the framework introduced by Geshkovski et al. (2023). By studying the corresponding mean-field Partial Differential Equation (PDE), which can be interpreted as a Wasserstein gradient flow, we provide a mathematical investigation of the long-term behavior of this system, with a particular focus on the emergence and persistence of meta-stable phases and clustering phenomena, which are key elements in applications such as next-token prediction. More specifically, we perform a perturbative analysis of the mean-field PDE around the i.i.d. uniform initialization and prove that, in the limit of a large number of tokens, the model remains close to a meta-stable manifold of solutions with a given structure (e.g., periodicity). Further, the structure characterizing the meta-stable manifold is explicitly identified, as a function of the inverse temperature parameter of the model, by the index maximizing a certain rescaling of Gegenbauer polynomials.
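
As a rough illustration of the particle system behind this description, the sketch below simulates tokens on the circle ($d = 2$) with an explicit Euler scheme. The interaction used here, $\dot\theta_i = \frac{1}{N}\sum_j e^{\beta\cos(\theta_j-\theta_i)}\sin(\theta_j-\theta_i)$, is an assumption: it is one unnormalized variant of the attention dynamics studied in the line of work of Geshkovski et al. (2023), written in angular coordinates, and the paper's exact normalization may differ. The function names, the step size, and the mode-amplitude diagnostic are illustrative choices, not taken from the paper.

```python
import math

def velocity(thetas, beta):
    """Angular velocity of each token under the ASSUMED unnormalized dynamics:
    dtheta_i/dt = (1/N) * sum_j exp(beta*cos(theta_j - theta_i)) * sin(theta_j - theta_i).
    """
    n = len(thetas)
    out = []
    for ti in thetas:
        s = 0.0
        for tj in thetas:
            d = tj - ti
            s += math.exp(beta * math.cos(d)) * math.sin(d)
        out.append(s / n)
    return out

def euler_step(thetas, beta, dt):
    """One explicit Euler step; angles are wrapped modulo 2*pi."""
    v = velocity(thetas, beta)
    return [(t + dt * vi) % (2.0 * math.pi) for t, vi in zip(thetas, v)]

def simulate(thetas, beta, dt, steps):
    """Apply `steps` Euler steps and return the final list of angles."""
    for _ in range(steps):
        thetas = euler_step(thetas, beta, dt)
    return thetas

def mode_amplitude(thetas, k):
    """Modulus of the k-th Fourier coefficient of the empirical measure,
    |(1/N) * sum_j exp(i*k*theta_j)|; it is 0 for a perfectly uniform
    configuration and 1 for k equally spaced point clusters."""
    n = len(thetas)
    re = sum(math.cos(k * t) for t in thetas) / n
    im = sum(math.sin(k * t) for t in thetas) / n
    return math.hypot(re, im)
```

For instance, seeding equally spaced angles with a small mode-3 perturbation, `[2*math.pi*j/60 + 1e-3*math.sin(3*2*math.pi*j/60) for j in range(60)]`, and calling `simulate(thetas, beta=5.0, dt=1e-3, steps=200)` should show `mode_amplitude(thetas, 3)` growing roughly exponentially while the configuration is still close to uniform. The pairwise loop costs $O(N^2)$ per step, so this is only meant for small $N$.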

Paper Structure

This paper contains 28 sections, 27 theorems, 146 equations, and 6 figures.

Key Result

Theorem 3.1

Assume that there exists $\mu_0 \in \mathcal{P}(\mathbb S^{d-1})$ such that $W_1(\mu_{\Xi^{(N)}}(0), \mu_0) \to 0$ as $N \to \infty$. Let $\mu(t)$ be the unique weak solution of the associated mean-field dynamics (\ref{eq:cont_pde}) with initial condition $\mu(0) = \mu_0$. Then, for any fixed $t \geq 0$,

Figures (6)

  • Figure 1: Schematic representation of our decomposition of the dynamics. Here, $\mu$ represents the "true" evolution of the system, while $\mu_L$ represents the linearized evolution around the uniform measure $\mu_\infty$; $f^\alpha$ denotes the evolution of the mean-field PDE with initial conditions $f_0^\alpha$ on the invariant manifold selected by $k_{\text{max}}$, and $f^{\alpha,K}$ is the approximation by Grenier's iterative scheme.
  • Figure 2: On the left, the plots of $\gamma_k$ as a function of $k$ for $\beta = 5$ (top) and $\beta = 7$ (bottom) are shown. The dashed blue line represents the value of $\gamma_k$ generalized to real $k$. From these graphs, we observe that the corresponding values of $k_{\text{max}}$ are $3$ and $4$, respectively, representing the predicted number of clusters in the large-$N$ limit. In the center, numerical simulations depict particle trajectories for $\beta = 5$ (top) and $\beta = 7$ (bottom), with $10^4$ particles whose initial conditions are sampled uniformly at random. On the right, the histograms of particle distributions at the end of the simulations are displayed, showcasing the formation of $3$ and $4$ clusters respectively.
  • Figure 3: Plot of the average time required for the empirical measure to exceed a fixed threshold. Note the logarithmic scale on the $x$-axis.
  • Figure 4: Evolution of the initial condition given by the uniform measure perturbed by white noise ($\sigma = 0.01$) with $\beta=5$. Note that the emerging solution is $3$-periodic, as predicted by Theorem \ref{thm:main2}.
  • Figure 5: Plots showing the evolution of the Wasserstein distance between $\mu_t$ and the uniform measure $\mu_\infty$ (solid line) and between $\mu_t$ and $\delta_3:=\sum_{k=0}^2\delta_{2\pi k/3}$ (dashed line). Note that the simulations with the initial condition $f^\alpha$ for $\alpha=0.01$ and $\alpha=0.001$ almost completely overlap, differing only by a time shift. This observation supports the validity of Assumption \ref{ass:3}.
  • ...and 1 more figure
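
The selection of $k_{\text{max}}$ described in the Figure 2 caption can be reproduced numerically under a simplifying assumption. For $d = 2$ and the unnormalized interaction kernel $e^{\beta\cos\theta}$ (an assumption, not a formula quoted from the paper), linearizing the mean-field PDE around the uniform measure gives Fourier mode $k$ a growth rate proportional to $k^2 I_k(\beta)/\beta$, where $I_k$ is the modified Bessel function of the first kind. The sketch below evaluates this expression and returns its maximizer; `bessel_i`, `growth_rate`, and `k_max` are hypothetical helper names, and the paper's $\gamma_k$ may differ from this expression by a normalization.

```python
import math

def bessel_i(k, beta, n=2048):
    """Modified Bessel function I_k(beta) for integer k, from the integral
    I_k(beta) = (1/pi) * int_0^pi exp(beta*cos(t)) * cos(k*t) dt,
    evaluated with the midpoint rule on n subintervals."""
    h = math.pi / n
    total = 0.0
    for m in range(n):
        t = (m + 0.5) * h
        total += math.exp(beta * math.cos(t)) * math.cos(k * t)
    return total * h / math.pi

def growth_rate(k, beta):
    """ASSUMED linear growth rate of Fourier mode k around the uniform
    measure on the circle: k^2 * I_k(beta) / beta."""
    return k * k * bessel_i(k, beta) / beta

def k_max(beta, k_cap=50):
    """Index in 1..k_cap with the largest assumed growth rate."""
    return max(range(1, k_cap + 1), key=lambda k: growth_rate(k, beta))
```

Under these assumptions, `k_max(5.0)` returns 3 and `k_max(7.0)` returns 4, which matches the cluster counts reported in the Figure 2 caption. The integral representation is used only to keep the sketch free of external dependencies; a library Bessel routine would serve equally well.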

Theorems & Definitions (61)

  • Theorem 3.1: Mean field limit
  • Remark 3.2
  • Remark 4.1
  • Theorem 4.2: Linear phase
  • Theorem 4.3: Quasi-linear phase
  • Remark 4.4
  • Theorem 4.5: Clustering phase
  • Remark 4.6
  • Theorem 4.7: Quasi-linear phase
  • Theorem A.1
  • ...and 51 more