Clustering in Causal Attention Masking

Nikita Karagodin, Yury Polyanskiy, Philippe Rigollet

TL;DR

This work presents a modification of the self-attention dynamics proposed by Geshkovski et al. to better reflect the practically relevant, causally masked attention used in transformer architectures for generative AI, and proves asymptotic convergence to a single cluster for arbitrary key-query matrices and a value matrix equal to the identity.

Abstract

This work presents a modification of the self-attention dynamics proposed by Geshkovski et al. (arXiv:2312.10794) to better reflect the practically relevant, causally masked attention used in transformer architectures for generative AI. This modification translates into an interacting particle system that cannot be interpreted as a mean-field gradient flow. Despite this loss of structure, we significantly strengthen the results of Geshkovski et al. (arXiv:2312.10794) in this context: While previous rigorous results focused on cases where all three matrices (Key, Query, and Value) were scaled identities, we prove asymptotic convergence to a single cluster for arbitrary key-query matrices and a value matrix equal to the identity. Additionally, we establish a connection to the classical Rényi parking problem from combinatorial geometry to make initial theoretical steps towards demonstrating the existence of meta-stable states.
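
The causal masking described above can be illustrated with a small simulation. This is a minimal sketch, not the authors' code: it assumes the common form of the dynamics in which particle $k$ moves along the tangential projection of a softmax-weighted average of the particles $j \le k$, with $K = Q = V = I$. The function names (`init_particles`, `causal_step`, `run`), the explicit Euler step with renormalization, and the default parameters are all illustrative choices.

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [c / norm for c in v]


def project_tangent(x, y):
    # Assumed projection onto the tangent space of the sphere at x: y - <x, y> x.
    s = dot(x, y)
    return [yi - s * xi for xi, yi in zip(x, y)]


def init_particles(n, d, seed=0):
    # Normalized Gaussian vectors are uniformly distributed on the unit sphere.
    rng = random.Random(seed)
    return [normalize([rng.gauss(0.0, 1.0) for _ in range(d)]) for _ in range(n)]


def causal_step(xs, beta, dt):
    # One explicit Euler step, then renormalization (assumes K = Q = V = I).
    d = len(xs[0])
    new_xs = []
    for k, xk in enumerate(xs):
        prefix = xs[: k + 1]  # causal mask: particle k only sees j <= k
        logits = [beta * dot(xk, xj) for xj in prefix]
        top = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(l - top) for l in logits]
        total = sum(weights)
        mean = [
            sum(w * xj[i] for w, xj in zip(weights, prefix)) / total
            for i in range(d)
        ]
        velocity = project_tangent(xk, mean)
        new_xs.append(normalize([a + dt * v for a, v in zip(xk, velocity)]))
    return new_xs


def run(xs, beta=1.0, dt=0.05, steps=200):
    for _ in range(steps):
        xs = causal_step(xs, beta, dt)
    return xs
```

In this sketch the first particle attends only to itself, so its velocity vanishes and it stays where it started, while the second particle is pulled toward the first. Calling `run(init_particles(6, 3))` returns the particle positions after 200 steps.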

Paper Structure

This paper contains 20 sections, 9 theorems, 98 equations, 6 figures, 1 table.

Key Result

Lemma 3.1

Let $x(t)$ be a solution of the ODE $\dot{x}(t) = \mathbf{P}_{x(t)}(Vx(t))$ defined on the unit sphere $\mathbb{S}^{d-1}$. Then, for almost every initial value $x(0) \in \mathbb{S}^{d-1}$, there exist $C, c > 0$ such that the following convergence rates for the geodesic distance $\textsf{dist}$ hold:
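
The ODE in the lemma can be explored numerically. The following is a minimal sketch under stated assumptions: $\mathbf{P}_x$ is taken to be the usual projection onto the tangent space of the sphere, $\mathbf{P}_x y = y - \langle x, y \rangle x$, the matrix $V$ is a diagonal example chosen purely for illustration, and time is discretized with explicit Euler steps followed by renormalization. The name `sphere_flow` and its parameters are hypothetical.

```python
import math


def matvec(V, x):
    return [sum(vij * xj for vij, xj in zip(row, x)) for row in V]


def normalize(v):
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def sphere_flow(V, x0, dt=0.01, steps=2000):
    # Integrate dx/dt = P_x(V x) with Euler steps, renormalizing to stay on the sphere.
    x = normalize(x0)
    for _ in range(steps):
        y = matvec(V, x)
        s = sum(a * b for a, b in zip(x, y))  # <x, Vx>
        velocity = [yi - s * xi for xi, yi in zip(x, y)]  # assumed form of P_x(Vx)
        x = normalize([xi + dt * vi for xi, vi in zip(x, velocity)])
    return x


# Illustrative diagonal V with a unique largest eigenvalue along the first axis.
V = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.5]]
x_final = sphere_flow(V, [0.1, 0.7, 0.7])
```

For this particular diagonal $V$, the trajectory aligns with the eigenvector of the largest eigenvalue, so `x_final` ends up close to the first coordinate axis. Other choices of $V$ (for example, non-symmetric ones) can behave differently and are not covered by this sketch.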

Figures (6)

  • Figure 1: Particle trajectories for different Value matrices. In all cases we take simple Query and Key matrices $K = Q = I_d$, temperature $\beta = 9$ and final time $T = 5000$ for $n = 32$ particles initialized uniformly at random on the sphere. Positions of particles at time $T$ are indicated by a red dot.
  • Figure 2: Evolution of the system \ref{eqn:csa} with $K=Q=V=I_2$, $n = 200$, $d = 2$, $\beta = 64$, showing strong Rényi centers (red) and Rényi centers (black) with $\delta = 4\beta^{-1/2}$. Note that strong Rényi centers are visually stationary (as per Lemma \ref{lemma:meta}) but do not explain all clusters. In turn, Rényi centers move and merge (one disappears between $t = 75$ and $t = 150$), but capture more meta-stable clusters.
  • Figure 3: Total percentage of particles consumed by Rényi and strong Rényi centers over time. We plot the average and the $0.1$ and $0.9$ quantiles over 5000 experiments with $n = 200$, $d = 2$, $\beta = 64$, $\delta = 4\beta^{-1/2}$.
  • Figure 4: Interaction function $h(x)$ for $\beta = 10$. The blue line is $h(x)$ and the red line is $x = \beta^{-1/2}$, close to $x = \tau_{\beta}^*$.
  • Figure 5: Eigenvalues of a value matrix at initialisation for albert-xlarge-v2.
  • ...and 1 more figure

Theorems & Definitions (21)

  • Lemma 3.1
  • Theorem 4.1
  • Conjecture 1
  • Conjecture 2
  • Lemma 5.1
  • Remark
  • Theorem 5.2: Clustering to frozen tokens for $K=Q=V=I_2$
  • Lemma 5.3
  • Remark
  • Lemma A.1: Properties of Interaction Functions
  • ...and 11 more