
Learning and Transferring Sparse Contextual Bigrams with Linear Transformers

Yunwei Ren, Zixuan Wang, Jason D. Lee

TL;DR

The paper proves that, provided a nontrivial correlation between the downstream and pretraining tasks, finetuning from a pretrained model allows us to bypass the initial sample-intensive stage, and shows empirically that the proposed algorithm can outperform SGD in this setting.

Abstract

Transformers have excelled in natural language modeling, and one reason behind this success is their exceptional ability to combine contextual information and global knowledge. However, the theoretical basis remains unclear. In this paper, we first introduce the Sparse Contextual Bigram (SCB), a natural extension of the classical bigram model, where the next token's generation depends on a sparse set of earlier positions determined by the last token. We then analyze the training dynamics and sample complexity of learning SCB using a one-layer linear transformer with a gradient-based algorithm. We show that when trained from scratch, the training process can be split into an initial sample-intensive stage where the correlation is boosted from zero to a nontrivial value, followed by a more sample-efficient stage of further improvement. Additionally, we prove that, provided a nontrivial correlation between the downstream and pretraining tasks, finetuning from a pretrained model allows us to bypass the initial sample-intensive stage. We also empirically demonstrate that our algorithm can outperform SGD in this setting and discuss its relationship with the usual softmax-based transformers.
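The SCB model can be made concrete with a small sampler. The sketch below is a minimal illustration under assumptions not stated in this summary: $N$ is the vocabulary size, $T$ the sequence length, and the $Q$ in settings such as $T=50, N=10, Q=2$ is read as the number of earlier positions the last token attends to. The helper names (`make_scb`, `next_token_distribution`, `sample_next`) and the matrices `P` (transition matrix) and `Qw` (sparse position weights) are hypothetical, chosen to mirror the $\bm{P}$ and $\bm{Q}$ that appear in the figure captions; the paper's exact definition may differ.

```python
import random

# Hypothetical SCB parameterization (illustrative assumptions, not the paper's notation):
#   N        : vocabulary size
#   T        : sequence length; x[T-1] is the last token
#   sparsity : number of earlier positions each last token attends to (read as Q)
#   P[j][i]  : Pr(next token = j | source token = i); each column sums to 1
#   Qw[k][t] : weight on earlier position t when the last token is k; each row
#              is a sparse probability vector over positions 0..T-2


def random_distribution(n, rng):
    """Random strictly positive probability vector of length n."""
    w = [rng.random() + 1e-12 for _ in range(n)]
    s = sum(w)
    return [v / s for v in w]


def make_scb(N, T, sparsity, rng):
    """Draw a random transition matrix P and sparse position weights Qw."""
    cols = [random_distribution(N, rng) for _ in range(N)]
    P = [[cols[i][j] for i in range(N)] for j in range(N)]
    Qw = []
    for _ in range(N):
        support = rng.sample(range(T - 1), sparsity)  # sparse set of earlier positions
        row = [0.0] * (T - 1)
        for pos, w in zip(support, random_distribution(sparsity, rng)):
            row[pos] = w
        Qw.append(row)
    return P, Qw


def next_token_distribution(x, P, Qw):
    """Pr(next = j | x) = sum_t Qw[x[-1]][t] * P[j][x[t]]."""
    N = len(P)
    probs = [0.0] * N
    for t, q in enumerate(Qw[x[-1]]):
        if q > 0.0:
            for j in range(N):
                probs[j] += q * P[j][x[t]]
    return probs


def sample_next(x, P, Qw, rng):
    """Sample the next token from the SCB conditional distribution."""
    probs = next_token_distribution(x, P, Qw)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


if __name__ == "__main__":
    rng = random.Random(0)
    N, T, sparsity = 10, 50, 2
    P, Qw = make_scb(N, T, sparsity, rng)
    x = [rng.randrange(N) for _ in range(T)]
    print(round(sum(next_token_distribution(x, P, Qw)), 6))  # 1.0
    print(sample_next(x, P, Qw, rng))
```

Because the next-token distribution in this sketch is linear in `P` and `Qw`, a one-layer linear-attention model whose value and attention matrices equal them reproduces it exactly, which is one way to see why a linear transformer is a natural learner for this task.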

Paper Structure

This paper contains 45 sections, 39 theorems, 222 equations, 6 figures, 1 algorithm.

Key Result

Theorem 3.1

Let $\varepsilon > 0$ be our target accuracy and let $\mathcal{T}_1 = \min\{ \tau \ge 0 : \max\{ \alpha_{V, \tau}, \alpha_{A, \tau} \} \ge \Theta(1/(QN)) \}$. We can choose the hyperparameters in Algorithm 1 such that within $\mathop{\mathrm{poly}}\nolimits(N, Q, 1/\varepsilon, \log \ldots)$ …
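The stopping time $\mathcal{T}_1$ marks the end of the sample-intensive first stage: it is the first step at which either signal, $\alpha_{V,\tau}$ or $\alpha_{A,\tau}$, reaches order $1/(QN)$. A minimal sketch, assuming recorded signal trajectories and a hypothetical constant `c` in place of the unspecified constant inside $\Theta(\cdot)$:

```python
def first_hitting_time(alpha_V, alpha_A, Q, N, c=1.0):
    """First step tau with max(alpha_V[tau], alpha_A[tau]) >= c / (Q * N).

    `c` is a hypothetical stand-in for the constant hidden in Theta(1/(QN)).
    Returns None if the recorded trajectory never reaches the threshold.
    """
    threshold = c / (Q * N)
    for tau, (a_v, a_a) in enumerate(zip(alpha_V, alpha_A)):
        if max(a_v, a_a) >= threshold:
            return tau
    return None


# With Q = 2, N = 10 and c = 1.0 the threshold is 0.05.
alpha_V = [0.0, 0.01, 0.03, 0.06]
alpha_A = [0.0, 0.02, 0.04, 0.05]
print(first_hitting_time(alpha_V, alpha_A, Q=2, N=10))  # 3
```

Taking the maximum of the two signals means Stage 1 ends as soon as either the value matrix or the attention matrix has acquired a nontrivial correlation.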

Figures (6)

  • Figure 1: Convergence analysis. We plot the distances to the ground truth, $\|\bm{V}-\bm{P}\|_\mu$ and $\|\bm{A}-\bm{Q}\|_\mu$, in different settings. After Stage 1 ends at $\tau=400$ (when $\alpha_A,\alpha_V\approx 0.1$), we train the transformer with either vanilla SGD or our proximal gradient method. Compared with SGD, $\ell_1$-regularized proximal gradient descent converges quickly, and its final solution (the star) recovers the ground truth. SGD either suffers from large gradient variance (when $\eta_2$ is large) or converges slowly (when $\eta_2'$ is small).
  • Figure 2: Similarity between softmax and linear attention. We train two transformers on the SCB task with the same ground truth ($T=50, N=10, Q=2$): one with a softmax attention layer (left) and one with a linear attention layer (middle). The learned attention patterns and value matrices (learned transition matrices) are very similar (left two plots), and both models converge to approximately the same loss (right plot).
  • Figure 3: Signals $\alpha_A,\alpha_V$ and the distances to the population process, $\bm{\Delta}_{\bm{A}}$ and $\bm{\Delta}_{\bm{V}}$. Under SGD, the distance of the attention matrix $\bm{A}$ to the population process keeps growing and eventually dominates the signal term. This explains why SGD fails to learn the correct attention pattern and why the signal saturates. In comparison, our proximal method substantially reduces the gradient noise and keeps the iterates close to the population process. Although $\|\bm{\Delta}_{\bm{A}}\|$ eventually grows due to the bias of the gradient estimate (which also slows the growth of the original signal), after normalization the method still approximately learns the correct pattern. Under both methods, $\|\bm{\Delta}_{\bm{V}}\|$ stays small empirically.
  • Figure 4: Similarity with the ground truth. After Stage 1, normalization further improves the solution found by the proximal method. With or without normalization, our proximal method outperforms vanilla SGD, which fails to recover the ground truth.
  • Figure 5: Simulation with larger $N$ and $T$. We simulate the SGD and $\ell_1$-regularized dynamics by replacing the minibatch noise in the dynamics formulas of Lemmas C.5 and C.6 with Gaussian noise whose variance scales inversely with the batch size. The conclusions drawn from the small-$N$ experiments still hold in these simulations: with $T = 100000$ and $N \in \{100, 500\}$, our $\ell_1$-regularized algorithm recovers the ground truth, since the distance to the population trajectory ($\Delta_A,\Delta_V$) stays very small, whereas the error along the SGD trajectories grows quickly at the same batch size.
  • ...and 1 more figure
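Figures 1, 3 and 5 contrast vanilla SGD with an $\ell_1$-regularized proximal gradient method under gradient noise whose variance scales as one over the batch size. The toy sketch below illustrates that contrast on a generic sparse least-squares target rather than on the paper's transformer parameters. The proximal step is the standard soft-thresholding (ISTA-style) update; all names (`noisy_gradient`, `prox_l1_step`, `run`) and the values of `eta`, `lam` and `batch_size` are hypothetical, and this is not Algorithm 1 from the paper.

```python
import random


def soft_threshold(w, lam):
    """Proximal operator of lam * |w|: shrink w toward zero by lam."""
    if w > lam:
        return w - lam
    if w < -lam:
        return w + lam
    return 0.0


def noisy_gradient(w, target, batch_size, rng, sigma=1.0):
    """Gradient of 0.5 * ||w - target||^2 plus Gaussian noise.

    The noise variance is sigma^2 / batch_size, a stand-in for minibatch noise.
    """
    std = sigma / batch_size ** 0.5
    return [wi - ti + rng.gauss(0.0, std) for wi, ti in zip(w, target)]


def sgd_step(w, g, eta):
    """Plain stochastic gradient step."""
    return [wi - eta * gi for wi, gi in zip(w, g)]


def prox_l1_step(w, g, eta, lam):
    """Gradient step followed by the l1 proximal map (ISTA-style update)."""
    return [soft_threshold(wi - eta * gi, eta * lam) for wi, gi in zip(w, g)]


def run(step, target, steps, batch_size, seed, **step_kwargs):
    """Iterate `step` from zero using noisy gradients; return the final iterate."""
    rng = random.Random(seed)
    w = [0.0] * len(target)
    for _ in range(steps):
        g = noisy_gradient(w, target, batch_size, rng)
        w = step(w, g, **step_kwargs)
    return w


if __name__ == "__main__":
    target = [1.0] + [0.0] * 49  # sparse ground truth: one nonzero coordinate
    w_prox = run(prox_l1_step, target, steps=500, batch_size=64, seed=0, eta=0.1, lam=0.2)
    w_sgd = run(sgd_step, target, steps=500, batch_size=64, seed=0, eta=0.1)
    print("exact zeros (prox):", sum(v == 0.0 for v in w_prox[1:]))
    print("exact zeros (SGD): ", sum(v == 0.0 for v in w_sgd[1:]))
```

With these settings the proximal iterates keep most off-support coordinates exactly at zero, while plain SGD leaves every coordinate perturbed by noise. In this toy the shrinkage also pulls the support coordinate toward $1-\lambda$ rather than $1$, a bias of the kind that a later rescaling step could offset.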

Theorems & Definitions (79)

  • Proof: remark on condition (c)
  • Theorem 3.1: main theorem
  • Theorem 4.1: informal version of the main transfer-learning theorem
  • Lemma B.1
  • Lemma B.2
  • Lemma B.3
  • Lemma B.4
  • Lemma B.5
  • Lemma B.6
  • Lemma B.7: Expected gradients
  • ...and 69 more