A phase transition between positional and semantic learning in a solvable model of dot-product attention

Hugo Cui, Freya Behrens, Florent Krzakala, Lenka Zdeborová

TL;DR

The paper studies how semantic attention emerges and competes with positional learning in transformers by analyzing a solvable model of tied, low-rank dot-product attention in a high-dimensional regime. Using a replica/GAMP framework, it derives a tight closed-form characterization of the global minimum of the empirical loss. This minimum corresponds either to positional or to semantic attention, and a phase transition in the sample complexity $\alpha=n/d$, taken of order $\Theta(1)$, determines which mechanism dominates. The study further shows that, given enough data, dot-product attention outperforms a purely positional linear baseline once it learns the semantic mechanism, highlighting the data-dependent advantage of attention architectures. These results clarify when and why semantic attention arises and offer a quantitative foundation for mechanistic interpretability in high-dimensional neural networks.
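To make the teacher-student setting concrete, here is a minimal NumPy sketch of the two forward passes. It is an illustration under stated assumptions, not the paper's exact definitions: a sequence is an $L\times d$ array `x`; the student adds positional encodings `p` to the tokens and scores them with a single low-rank matrix `Q` ($d\times r$) shared by queries and keys ("tied"); the teacher mixes a semantic softmax attention built from target weights `Q_star` with a fixed positional matrix `A`, putting weight `omega` on the positional part. The $1/d$ scaling inside the softmax, the choice to mix the raw tokens `x` as values, the mixture form, and all function names are assumptions. The numerical setting (L = 2, Σ = 0.25 I, p_1 = 1_d = -p_2, A, ω = 0.3) is taken from the Figure 2 caption.

```python
import numpy as np


def softmax_rows(s: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = np.exp(s - s.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def student(x: np.ndarray, p: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Hypothetical tied low-rank dot-product attention.

    x: (L, d) tokens, p: (L, d) positional encodings, Q: (d, r) weights used
    for both queries and keys. The 1/d scaling is an assumption.
    """
    d = x.shape[1]
    h = x + p                                # tokens plus positional encodings
    scores = (h @ Q) @ (h @ Q).T / d         # tied: same Q on both sides
    return softmax_rows(scores) @ x          # assumed: mix the raw tokens


def teacher(x: np.ndarray, Q_star: np.ndarray, A: np.ndarray, omega: float) -> np.ndarray:
    """Hypothetical teacher: content-based (semantic) attention from Q_star,
    mixed with a fixed position-based attention matrix A."""
    d = x.shape[1]
    semantic = softmax_rows((x @ Q_star) @ (x @ Q_star).T / d)
    return ((1 - omega) * semantic + omega * A) @ x


rng = np.random.default_rng(0)
L, d, r = 2, 1000, 1
x = rng.normal(scale=0.5, size=(L, d))       # Sigma_1 = Sigma_2 = 0.25 I
p = np.stack([np.ones(d), -np.ones(d)])      # p_1 = 1_d = -p_2
Q_star = rng.normal(size=(d, r))             # Q_star ~ N(0, I_d)
A = np.array([[0.6, 0.4], [0.4, 0.6]])

y = teacher(x, Q_star, A, omega=0.3)
y_hat = student(x, p, Q_star)
print(y.shape, y_hat.shape)                  # (2, 1000) (2, 1000)
```

In this assumed parametrization, `omega=1` gives a purely positional teacher (output `A @ x`) and `omega=0` a purely semantic one. A student with `Q` aligned to `p_1` can only reproduce position-dependent mixing, while one aligned to `Q_star` can track the content-dependent part.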

Abstract

Many empirical studies have provided evidence for the emergence of algorithmic mechanisms (abilities) in the learning of language models that lead to qualitative improvements in the models' capabilities. Yet, a theoretical characterization of how such mechanisms emerge remains elusive. In this paper, we take a step in this direction by providing a tight theoretical analysis of the emergence of semantic attention in a solvable model of dot-product attention. More precisely, we consider a non-linear self-attention layer with trainable tied and low-rank query and key matrices. In the asymptotic limit of high-dimensional data and a comparably large number of training samples, we provide a tight closed-form characterization of the global minimum of the non-convex empirical loss landscape. We show that this minimum corresponds to either a positional attention mechanism (with tokens attending to each other based on their respective positions) or a semantic attention mechanism (with tokens attending to each other based on their meaning), and provide evidence of an emergent phase transition from the former to the latter with increasing sample complexity. Finally, we compare the dot-product attention layer to a linear positional baseline and show that it outperforms the latter using the semantic mechanism, provided it has access to sufficient data.
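A purely positional linear baseline can be sketched as follows. We assume (our reading, not necessarily the paper's exact definition) that it predicts $\hat{y} = B x$ with a single unregularized $L\times L$ matrix $B$ that mixes tokens by position only. Minimizing the squared error over $n$ training sequences gives the normal equations $B \sum_\mu x^\mu x^{\mu\top} = \sum_\mu y^\mu x^{\mu\top}$, which the helper below solves. The function names and the synthetic data are hypothetical.

```python
import numpy as np


def fit_positional_baseline(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """Least-squares fit of a hypothetical positional baseline y_hat = B @ x.

    X, Y: arrays of shape (n, L, d). B is one (L, L) matrix shared across all
    sequences, so it can only mix tokens according to their positions.
    """
    xx = np.einsum("nld,nkd->lk", X, X)   # sum over samples of X_mu @ X_mu.T
    yx = np.einsum("nld,nkd->lk", Y, X)   # sum over samples of Y_mu @ X_mu.T
    # Solve B @ xx = yx by transposing: xx.T @ B.T = yx.T
    return np.linalg.solve(xx.T, yx.T).T


def mse(X: np.ndarray, Y: np.ndarray, B: np.ndarray) -> float:
    """Mean squared error of the baseline, averaged over all entries."""
    pred = np.einsum("lk,nkd->nld", B, X)
    return float(np.mean((Y - pred) ** 2))


rng = np.random.default_rng(1)
n, L, d = 200, 2, 50
X = rng.normal(size=(n, L, d))
B_true = np.array([[0.6, 0.4], [0.4, 0.6]])
Y = np.einsum("lk,nkd->nld", B_true, X)   # a purely positional target
B_hat = fit_positional_baseline(X, Y)
print(np.round(B_hat, 3), mse(X, Y, B_hat))
```

On a purely positional target the fit recovers `B_true`. Because `B` is the same for every input, it cannot make the mixing depend on token content; that content dependence is what the semantic mechanism of the attention layer adds.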

Paper Structure (61 sections, 115 equations, 15 figures, 1 table, 2 algorithms)

Figures (15)

  • Figure 1: A phase transition in a toy model of attention. (A) We investigate a tied low-rank attention model in a teacher-student setting. The teacher mixes the $L$ individual tokens of dimension $d$ according to a semantic attention matrix (a function of the tokens' content $\mathbf x$) and a positional attention matrix (a function of the tokens' positions). The student can only use the positional encodings $\mathbf p$ to fit the positional properties of the teacher. (B) Schematic view of the student's loss landscape, which contains both a positional and a semantic minimum. (C) In the asymptotic high-dimensional limit, the global minimum switches as a function of the sample complexity and of the composition of the teacher, constituting a phase transition between positional and semantic learning.
  • Figure 2: Mixed positional/semantic teacher for $\omega=0.3$. The setting is $r_s=r_t=1$, $L=2$, $A=\begin{pmatrix}0.6&0.4\\0.4&0.6\end{pmatrix}$, $\boldsymbol{\Sigma}_1=\boldsymbol{\Sigma}_2=0.25\,\mathbb{I}_d$, $\boldsymbol{p}_1=\mathbf{1}_d=-\boldsymbol{p}_2$ and $\boldsymbol{Q}_\star\sim\mathcal{N}(0,\mathbb{I}_d)$. (left) Solid lines: difference in training loss $\Delta\epsilon_t$ between the semantic and positional solutions of \ref{eq:replica_SP} in Result \ref{res:Asymptotics}. Markers: difference in training loss at convergence obtained by training the model \ref{eq:student} with gradient descent initialized at $\boldsymbol{Q}_\star$ and at $\boldsymbol{p}_1$, respectively. Marker color as in Fig. \ref{fig:sem-pos-crossover-phase-diagram}. (center) Overlap $\theta$ between the learnt weights $\hat{\boldsymbol{Q}}$ and the target weights $\boldsymbol{Q}_\star$, and overlap $m$ between the learnt weights $\hat{\boldsymbol{Q}}$ and the positional embedding $\boldsymbol{p}_1$. Solid lines represent the theoretical characterization of these two summary statistics provided by Result \ref{res:Asymptotics}; only the solution of \ref{eq:replica_SP} with the lowest training loss found is shown (i.e. the positional solution for $\alpha<\alpha_c$ and the semantic solution otherwise). Markers represent experimental measurements of these quantities for gradient descent at convergence, initialized at $\boldsymbol{p}_1$ for $\alpha<\alpha_c$ and at $\boldsymbol{Q}_\star$ for $\alpha>\alpha_c$. (right) MSE $\epsilon_g^{min}$ achieved by the dot-product attention at the global minimum (Result \ref{res:Asymptotics}), and MSE $\epsilon_g^{lin}$ achieved by the dense linear baseline \ref{eq:linear} (Result \ref{res:linear}). Markers indicate the MSE reached experimentally by the model \ref{eq:student} trained with gradient descent, initialized as for the overlaps.
All experiments were performed in $d=1000$ with the PyTorch implementation of full-batch gradient descent, for $T=5000$ epochs with learning rate $\eta=0.15$. Each point is averaged over $24$ instances of the problem.
  • Figure 3: Phase transition between semantic and positional training loss. The setting and experimental protocol are identical to those of Fig. \ref{fig:free-entropy-crossover-omega0.3}. (left) When $d$ and $n$ are scaled jointly at fixed $\alpha=1.5$, the overlaps $\theta$ and $m$ concentrate, at different values for the positional and for the semantic local minimum. We show 30 runs for each $d \in \{10,15,23,36,56,87,135,209,323,500\}$. (center) The color map represents the difference in training loss at convergence when training the model \ref{eq:student} with the PyTorch implementation of full-batch gradient descent from an initialization at $\mathbf p_1$ and at $\boldsymbol{Q}_\star$, respectively. The green dashed line represents the theoretical prediction for the threshold $\alpha_c(\omega)$ above which the semantic solution of \ref{eq:replica_SP} in Result \ref{res:Asymptotics} has lower loss than the positional solution. (right) The color map represents the difference in test MSE at convergence between the attention model \ref{eq:attention_student}, trained with the PyTorch implementation of full-batch gradient descent initialized at $\boldsymbol{Q}_\star$, and the dense linear baseline \ref{eq:linear}. The red dashed lines indicate the theoretical prediction (following from Result \ref{res:Asymptotics} and Result \ref{res:linear}) for the threshold sample complexity $\alpha_l(\omega)$ above which the dot-product attention \ref{eq:student} outperforms the baseline \ref{eq:linear}.
  • Figure 4: Several solutions exist for the histogram task. Elements of attention matrices for the histogram task at local minima of the empirical loss landscape. We generated a dataset of sequences by sampling each token of the sequence i.i.d. from the uniform distribution over all tokens. The target for a given input sequence $\mathbf x = [A,D,D,C]$ is the number of occurrences of each token in the complete sequence, i.e. $\mathbf y = [1,2,2,1]$. Models were trained with their respective frozen initialization using $n=35,000$ samples and the Adam optimizer. Top row: the attention matrix of the positional solution is largely independent of the specific input sequence. Bottom row: the attention matrices of the semantic solution vary with the input tokens. Red squares highlight the elements $A_{ij}$ where $x_i = x_j$.
  • Figure 5: Comparison of the attention layer activations for different sequences for $\theta_{pos}$ and $\tilde{\theta}_{pos}$.
  • ...and 10 more figures
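The histogram task of Figure 4 is easy to reproduce. The sketch below uses hypothetical function names and an assumed four-letter alphabet "ABCD". It samples tokens i.i.d. uniformly, computes the target count of each token in the full sequence, and builds the indicator mask $x_i = x_j$ that the red squares highlight. The row sums of that mask equal the targets, which is why attending to same-token positions (a content-based, "semantic" pattern) carries the information needed to solve the task.

```python
import random


def histogram_targets(seq: list[str]) -> list[int]:
    """For each position i, count how often seq[i] occurs in the whole sequence."""
    return [seq.count(tok) for tok in seq]


def same_token_mask(seq: list[str]) -> list[list[int]]:
    """Indicator matrix M[i][j] = 1 if seq[i] == seq[j] (red squares in Figure 4)."""
    return [[int(a == b) for b in seq] for a in seq]


def sample_sequence(alphabet: str, length: int, rng: random.Random) -> list[str]:
    """Sample each token i.i.d. uniformly from the (assumed) alphabet."""
    return [rng.choice(alphabet) for _ in range(length)]


seq = ["A", "D", "D", "C"]
print(histogram_targets(seq))                      # [1, 2, 2, 1]
# Row sums of the same-token mask reproduce the histogram targets.
print([sum(row) for row in same_token_mask(seq)])  # [1, 2, 2, 1]

rng = random.Random(0)
sample = sample_sequence("ABCD", 4, rng)
print(sample, histogram_targets(sample))
```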