
Transformers Provably Learn Sparse Token Selection While Fully-Connected Nets Cannot

Zixuan Wang, Stanley Wei, Daniel Hsu, Jason D. Lee

TL;DR

The paper tackles the learnability of the sparse token selection task $\text{STS}_q$ by analyzing the gradient-descent (GD) dynamics of a one-layer transformer with stochastic positional encoding (SPE). It proves that, with width $m=\tilde{O}(d+q\log T)$, GD globally converges to a solution that solves $\text{STS}_q$, while any FCN requires width $\tilde{\Omega}(Td)$ to achieve comparable performance, establishing an exponential separation in expressive power. It also proves a length generalization guarantee under SPE and provides empirical evidence of convergence to the ground-truth directions and of better out-of-distribution generalization than fixed positional encodings. These results collectively justify the inductive bias of transformers for this class of arithmetic tasks and offer insight into the benefits of randomized positional encodings for robust length generalization.

Abstract

The transformer architecture has prevailed in various deep learning settings due to its exceptional capabilities to select and compose structural information. Motivated by these capabilities, Sanford et al. proposed the sparse token selection task, in which transformers excel while fully-connected networks (FCNs) fail in the worst case. Building upon that, we strengthen the FCN lower bound to an average-case setting and establish an algorithmic separation of transformers over FCNs. Specifically, a one-layer transformer trained with gradient descent provably learns the sparse token selection task and, surprisingly, exhibits strong out-of-distribution length generalization. We provide empirical simulations to justify our theoretical findings.
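To make the task concrete, here is a minimal Python sketch of the target function, assuming the q-sparse averaging (qSA) formulation of Sanford et al. that the training objective is named after: given $T$ tokens in $\mathbb{R}^d$ and a set of $q$ distinct positions, the target is the average of the selected tokens. The function name `q_sparse_average` and the list-of-lists representation are illustrative choices, not the paper's code.

```python
from typing import List, Sequence

Vector = List[float]


def q_sparse_average(tokens: Sequence[Sequence[float]], subset: Sequence[int]) -> Vector:
    """Return the mean of tokens[j] over j in `subset` (one q-sparse averaging target)."""
    if not subset:
        raise ValueError("subset must contain at least one position")
    if len(set(subset)) != len(subset):
        raise ValueError("subset positions must be distinct")
    d = len(tokens[0])
    total = [0.0] * d
    # Accumulate the selected tokens coordinate by coordinate.
    for j in subset:
        for k in range(d):
            total[k] += tokens[j][k]
    q = len(subset)
    return [value / q for value in total]


# Usage: T = 4 tokens in R^2, q = 2 selected positions.
tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
print(q_sparse_average(tokens, [0, 2]))  # [1.5, 1.0]
```

The target depends only on which positions are selected, which is why the positional encodings carry most of the burden in the constructions and figures discussed below.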

Paper Structure

This paper contains 44 sections, 30 theorems, 251 equations, and 8 figures.

Key Result

Theorem 1

For any $2\leq q<T/4$, $\epsilon\in \left(0,\frac{dT}{100(T-q)q}\right)$, $\eta\leq \frac{1}{20d^2}$, and $\bm{x}_{\text{query}}=\mathbf{0}_d$, if we run gradient descent on the population loss of the qSA training objective with zero initialization $\bm{W}(0)=\mathbf{0}_{(d+T)\times (d+T)}$, $\bm{V}(0)=\mathbf{0}_{d\cdots}$ …
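As a quick arithmetic check on the hypotheses quoted in this excerpt, the hypothetical helper below computes the admissible ranges for $\epsilon$ and $\eta$ from $d$, $T$, and $q$. Only the three conditions visible in the statement are encoded; requiring $\eta>0$ is an added assumption, and any further conditions in the full theorem are not checked.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Theorem1Ranges:
    eps_upper: float  # epsilon must lie in the open interval (0, eps_upper)
    eta_max: float    # step size must satisfy eta <= eta_max


def theorem1_ranges(d: int, T: int, q: int) -> Theorem1Ranges:
    """Admissible epsilon/eta ranges from the conditions visible in Theorem 1 (hypothetical helper)."""
    if d < 1:
        raise ValueError("d must be a positive integer")
    if not (2 <= q < T / 4):
        raise ValueError("Theorem 1 requires 2 <= q < T/4")
    eps_upper = d * T / (100 * (T - q) * q)  # dT / (100 (T - q) q)
    eta_max = 1 / (20 * d ** 2)              # 1 / (20 d^2)
    return Theorem1Ranges(eps_upper, eta_max)


def satisfies_theorem1(d: int, T: int, q: int, eps: float, eta: float) -> bool:
    """True if (d, T, q, eps, eta) meets the checked conditions.

    Assumption: eta > 0 (a positive step size); the quoted theorem only bounds eta above.
    """
    try:
        r = theorem1_ranges(d, T, q)
    except ValueError:
        return False
    return 0 < eps < r.eps_upper and 0 < eta <= r.eta_max


# Usage with T = 200, q = 3 (as in Figure 2) and a hypothetical d = 10.
r = theorem1_ranges(d=10, T=200, q=3)
print(r.eps_upper, r.eta_max)
print(satisfies_theorem1(10, 200, 3, eps=0.01, eta=5e-4))  # True
```

With these values the bounds work out to $\epsilon < 2000/59100 \approx 0.0338$ and $\eta \le 1/2000$; note that the step-size bound tightens quadratically in $d$.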

Figures (8)

  • Figure 1: The above figures describe the training trajectory of the one-layer transformer model with attention layer $\bm{W}$ attending to the full matrix $[\bm{Z},{\bm{z}}_{\text{query}}]$. Left (Global convergence of the transformer): we plot the training loss for the one-layer transformer with stochastic PE, complementing it with the inverse loss plot. The training loss converges to the global minimum of 0, and the inverse loss increases linearly, indicating the $O(1/t)$ convergence rate. Right (Cosine similarity with the ground-truth): we plot the evolution of the cosine similarity between $\bm{W}$ and $\bm{W}^\star$ and between $\bm{V}$ and $\bm{V}^\star$ throughout training, where $\bm{W}^\star$ and $\bm{V}^\star$ are the ground-truth matrices. The two cosine similarity curves gradually converge to 1, indicating the one-layer transformer eventually converges to the desired direction.
  • Figure 2: Length generalization superiority of stochastic PE: We plot the out-of-distribution error throughout training on each of the four length generalization tasks ($T_{\text{test}}=250, 300, 350, 400$) when $T=200$ and $q=3$. Observe that the one-layer transformer with stochastic positional encoding has a clear advantage over the fixed positional encoding architecture in all four tasks: the stochastic architecture converges after 10k steps, while the length generalization error of the fixed architecture stays above some constant.
  • Figure 3: Interpretable training: For the full (practical) transformer model, we present heat maps of the self-attention matrix $\bm{W}$ and the value matrix $\bm{V}$ at initialization and after convergence. We initialize $\bm{W},\bm{V}$ randomly at $t=0$. After training, only the sub-block of $\bm{W}$ that attends to the positional encodings $\bm{E}$ converges to the identity direction, while all other entries converge to 0; in $\bm{V}$, only the sub-block that attends to the input tokens $\bm{X}$ converges to the identity direction, with all other entries converging to 0.
  • Figure 4: The length generalization performance and OOD performance on unseen $q_{\text{test}}$-subsets. Top: Length generalization. Stochastic PE converges to 0 validation loss, whereas fixed PE does not; every fixed-PE run ends with a validation loss of at least 0.15. Bottom: Generalization to unseen $q_{\text{test}}$-subsets. While both stochastic and fixed PE eventually converge to 0 validation loss, stochastic PE converges slightly faster, as the zoomed-in plots near the end of training show. Additionally, the fixed PE's validation performance worsens as $q_{\text{test}}$ increases.
  • Figure 5: The training trajectory of Adam. The length generalization advantage of stochastic positional encoding is similar to that in the first appendix experiment. While Adam may drive the validation loss on unseen $q_{\text{test}}$-subsets to 0 in the long run, stochastic PE dominates in such OOD performance for all practical purposes related to early stopping.
  • ...and 3 more figures
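Figure 3 suggests that, after training, attention acts only through the positional sub-block of $\bm{W}$ (in an identity direction) and $\bm{V}$ reads out only the token sub-block. The sketch below illustrates why such a configuration solves the task, under simplifying assumptions that are ours rather than the paper's: one-hot positional encodings, a query whose positional part is the sum of the one-hot encodings of the $q$ selected positions, $\bm{W}=\beta\,\mathrm{blockdiag}(\mathbf{0}_{d\times d},\bm{I}_T)$, $\bm{V}=[\bm{I}_d,\mathbf{0}]$, and keys restricted to the $T$ input columns. As the scale $\beta$ grows, the softmax weights concentrate uniformly on the selected positions and the output approaches their average. The function `attention_readout` is hypothetical.

```python
import math
from typing import List, Sequence


def attention_readout(tokens: Sequence[Sequence[float]], subset: Sequence[int], beta: float) -> List[float]:
    """One softmax-attention readout under the simplified construction described above."""
    T, d = len(tokens), len(tokens[0])
    selected = set(subset)
    # Assumed query encoding: positional part = sum of one-hot PEs of the selected positions.
    query_pe = [1.0 if i in selected else 0.0 for i in range(T)]
    # W = beta * blockdiag(0_{d x d}, I_T): score for key i is beta * <query_pe, e_i> = beta * query_pe[i].
    scores = [beta * query_pe[i] for i in range(T)]
    # Numerically stable softmax over the T input positions.
    shift = max(scores)
    weights = [math.exp(s - shift) for s in scores]
    norm = sum(weights)
    weights = [w / norm for w in weights]
    # V = [I_d, 0] keeps only the token part x_i of each column.
    return [sum(weights[i] * tokens[i][k] for i in range(T)) for k in range(d)]


tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
for beta in (0.0, 5.0, 50.0):
    print(beta, attention_readout(tokens, [0, 2], beta))
```

With `beta=0` the weights are uniform over all $T$ positions, so the readout is the average of every token; as `beta` grows, the error relative to the average of the selected tokens shrinks roughly in proportion to $e^{-\beta}$.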

Theorems & Definitions (61)

  • Definition 1
  • Definition 2: One-hot PE
  • Definition 3: Near-orthogonal PE
  • Definition 4: Reparameterization
  • Definition 5: Transformer with stochastic positional encoding
  • Theorem 1: Joint training with one-hot positional encoding
  • Lemma 1: Symmetry in different q-sparse subsets (joint training), informal
  • Theorem 2: Joint training with stochastic positional encoding
  • Lemma 2: Consequence of the induction hypothesis for joint training with stochastic PE, informal
  • Theorem 3
  • ...and 51 more
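Definition 3 (near-orthogonal PE) and Definition 5 (stochastic PE) are only named in this listing. One common way to obtain nearly orthogonal encodings, which we assume purely for illustration, is to draw independent Gaussian vectors in a high dimension and normalize them; pairwise inner products then concentrate around 0 at a scale of about $1/\sqrt{\text{dim}}$. A stochastic-PE model would resample such encodings rather than keep them fixed. The paper's exact distribution and near-orthogonality conditions may differ, and both function names below are hypothetical.

```python
import math
import random
from typing import List, Optional


def sample_near_orthogonal_pe(num_positions: int, dim: int, seed: Optional[int] = None) -> List[List[float]]:
    """Draw `num_positions` random unit vectors in R^dim (assumed Gaussian-then-normalize scheme)."""
    if num_positions < 1 or dim < 1:
        raise ValueError("num_positions and dim must be positive")
    rng = random.Random(seed)
    encodings = []
    for _ in range(num_positions):
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        length = math.sqrt(sum(x * x for x in v))
        encodings.append([x / length for x in v])
    return encodings


def max_coherence(encodings: List[List[float]]) -> float:
    """Largest |<p_i, p_j>| over distinct pairs: small values mean nearly orthogonal."""
    worst = 0.0
    for i in range(len(encodings)):
        for j in range(i + 1, len(encodings)):
            dot = sum(a * b for a, b in zip(encodings[i], encodings[j]))
            worst = max(worst, abs(dot))
    return worst


# Under this assumed scheme, a stochastic PE would draw a fresh set like this per training sequence.
pes = sample_near_orthogonal_pe(num_positions=50, dim=1024, seed=0)
print(max_coherence(pes))
```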