
Linear Transformers are Versatile In-Context Learners

Max Vladymyrov, Johannes von Oswald, Mark Sandler, Rong Ge

TL;DR

The paper proves that each layer of a linear transformer maintains a weight vector for an implicit linear regression problem and can be interpreted as performing a variant of preconditioned gradient descent.

Abstract

Recent research has demonstrated that transformers, particularly linear attention models, implicitly execute gradient-descent-like algorithms on data provided in-context during their forward inference step. However, their ability to handle more complex problems remains unexplored. In this paper, we prove that each layer of a linear transformer maintains a weight vector for an implicit linear regression problem and can be interpreted as performing a variant of preconditioned gradient descent. We also investigate the use of linear transformers in a challenging scenario where the training data is corrupted with different levels of noise. Remarkably, we demonstrate that for this problem linear transformers discover an intricate and highly effective optimization algorithm that matches or surpasses many reasonable baselines. We analyze this algorithm and show that it is a novel approach incorporating momentum and adaptive rescaling based on noise levels. Our findings show that even linear transformers possess the surprising ability to discover sophisticated optimization strategies.
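The link between one attention layer and one gradient step can be checked numerically in the simplest case. The sketch below assumes a single linear self-attention layer with a residual connection in which only the $n$ context tokens act as keys and values, tokens stacked as $e_i = (x_i, y_i)$ with the query as $(x_t, 0)$, and hand-set (not trained) matrices $P$, $Q$ chosen so that the layer reproduces one step of preconditioned gradient descent from $w_0 = 0$ with a preconditioner $\Gamma$. The function name `lsa_layer`, the dimensions, and the sign convention (the query's label slot ends up holding the negated prediction, and context labels become residuals) are illustrative assumptions, not taken from the paper's code.

```python
import numpy as np


def lsa_layer(Z, P, Q):
    """One linear self-attention layer with a residual connection (illustrative).

    Z has shape (n + 1, d + 1): rows 0..n-1 are context tokens (x_i, y_i),
    the last row is the query token (x_t, 0). Only context tokens serve as
    keys/values, so the query does not attend to itself.
    """
    n = Z.shape[0] - 1
    ctx = Z[:n]                     # (n, d+1) context tokens e_i
    scores = ctx @ Q @ Z.T          # scores[i, j] = e_i^T Q e_j
    values = ctx @ P.T              # values[i] = (P e_i)^T
    update = scores.T @ values / n  # row j = (1/n) sum_i (e_i^T Q e_j) (P e_i)^T
    return Z + update


rng = np.random.default_rng(0)
n, d = 20, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d)
x_t = rng.normal(size=d)
Gamma = 0.1 * np.eye(d)  # preconditioner; a scaled identity is plain GD with step 0.1

# Stack tokens: context rows (x_i, y_i), query row (x_t, 0).
Z = np.zeros((n + 1, d + 1))
Z[:n, :d] = X
Z[:n, d] = y
Z[n, :d] = x_t

# Hand-set weights: Q compares x-parts through Gamma, P writes -y_i into the label slot.
Q = np.zeros((d + 1, d + 1))
Q[:d, :d] = Gamma
P = np.zeros((d + 1, d + 1))
P[d, d] = -1.0

out = lsa_layer(Z, P, Q)

# One preconditioned GD step on L(w) = ||X w - y||^2 / (2n), starting from w0 = 0.
grad0 = -(X.T @ y) / n
w1 = -Gamma.T @ grad0

print(np.isclose(-out[n, d], w1 @ x_t))     # True: query slot holds -<w1, x_t>
print(np.allclose(out[:n, d], y - X @ w1))  # True: context slots hold residuals
```

Because $Q$ only touches the $x$ coordinates and $P$ only writes to the label coordinate, the $x$ parts of all tokens pass through unchanged; the trained models discussed in the paper are not restricted to this form.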

Paper Structure (31 sections, 9 theorems, 36 equations, 7 figures, 1 table)

Key Result

Theorem 4.1

Suppose the output of a linear transformer at the $l$-th layer is $(x^{l}_1, y^{l}_1), (x^{l}_2, y^{l}_2), \ldots, (x^{l}_n, y^{l}_n), (x^{l}_t, y^{l}_t)$. Then there exist matrices $M^{l}$, vectors $u^{l}, w^{l}$ and scalars $a^{l}$ such that
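Theorem 4.1 concerns general weights; the sketch below only illustrates its simplest special case. Assuming hand-set weights that never modify the $x$ coordinates (the same kind of construction as a one-step gradient-descent layer, not the paper's learned weights), stacking $L$ linear self-attention layers keeps an implicit weight vector $w^l$ that follows $L$ steps of preconditioned gradient descent, and under this sketch's sign convention the query's label slot equals $-\langle w^l, x_t\rangle$ after layer $l$. The helpers `lsa_layer` and `gd_iterates`, the dimensions, and $\Gamma = 0.2I$ are hypothetical choices for illustration.

```python
import numpy as np


def lsa_layer(Z, P, Q):
    """Linear self-attention with residual; only the n context rows act as keys/values."""
    n = Z.shape[0] - 1
    ctx = Z[:n]
    return Z + (ctx @ Q @ Z.T).T @ (ctx @ P.T) / n


def gd_iterates(X, y, Gamma, num_steps):
    """Preconditioned GD on L(w) = ||X w - y||^2 / (2n) from w0 = 0; returns all iterates."""
    n, d = X.shape
    w = np.zeros(d)
    iterates = []
    for _ in range(num_steps):
        w = w + Gamma.T @ X.T @ (y - X @ w) / n  # w - Gamma^T grad L(w)
        iterates.append(w.copy())
    return iterates


rng = np.random.default_rng(1)
n, d, num_layers = 30, 4, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d)
x_t = rng.normal(size=d)
Gamma = 0.2 * np.eye(d)

Z = np.zeros((n + 1, d + 1))
Z[:n, :d] = X
Z[:n, d] = y
Z[n, :d] = x_t

Q = np.zeros((d + 1, d + 1))
Q[:d, :d] = Gamma
P = np.zeros((d + 1, d + 1))
P[d, d] = -1.0

query_slots = []
for _ in range(num_layers):
    Z = lsa_layer(Z, P, Q)  # same weights reused in every layer
    query_slots.append(Z[n, d])

for slot, w in zip(query_slots, gd_iterates(X, y, Gamma, num_layers)):
    print(np.isclose(slot, -(w @ x_t)))  # True for every layer
```

In this special case the context label slots carry the residuals $y_i - \langle w^l, x_i\rangle$, which is what lets each layer compute the next gradient step from the previous layer's output.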

Figures (7)

  • Figure 1: In-context learning performance for the noisy linear regression problem across models with different numbers of layers and $\sigma_{max}$ for $\sigma_\tau\sim U(0,\sigma_{max})$. Each marker corresponds to a separately trained model with a given number of layers. Models with diagonal attention weights (Diag) match those with full attention weights (Full). Models specialized to a fixed noise level (GD$^{++}$) perform poorly, similar to a Ridge Regression solution with a constant noise level (ConstRR). Among the baselines, only the tuned exact Ridge Regression solution (TunedRR) is comparable to linear transformers.
  • Figure 2: Per-variance profile of the models' behavior for uniform noise variance $\sigma_\tau\sim U(0,\sigma_{max})$. Top two rows: 7-layer models with varying $\sigma_{max}$. Bottom row: models with varying numbers of layers, fixed $\sigma_{max}=5$. In-distribution noise is shaded gray.
  • Figure 3: In-context learning performance for noisy linear regression across models with varying numbers of layers for categorical noise variance $\sigma_\tau \in \{1, 3\}$ and $\sigma_\tau \in \{1, 3, 5\}$. Top: loss for models with varying numbers of layers, and per-variance profile for models with 7 layers. Bottom: per-variance profile of the model across different numbers of layers. In-distribution noise is shaded gray.
  • Figure 4: Weights for a 4-layer linear transformer with Full parametrization trained with categorical noise $\sigma_\tau \in \{1, 3\}$. Top: weights for the $Q^l$ matrix; bottom: weights for the $P^l$ matrix.
  • Figure 5: Linear transformer models show a consistent decrease in error per layer when trained on data with mixed noise variance $\sigma_\tau\sim U(0,5)$. The error bars measure variance over $5$ training seeds.
  • ...and 2 more figures
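The two Ridge Regression baselines named in Figure 1 differ only in how the penalty is chosen. The sketch below compares them under stated assumptions that are not taken from the paper: $w \sim \mathcal{N}(0, I)$, Gaussian inputs, per-task noise $\sigma_\tau \sim U(0, \sigma_{max})$, a "tuned" solver that uses $\lambda = \sigma_\tau^2$ (the Bayes-optimal ridge penalty under this Gaussian prior), and a "constant" solver that uses one fixed $\lambda = 1$ for every task, as if $\sigma_\tau = 1$ always. The names `ridge` and `sample_task`, the dimensions, and the fixed penalty are illustrative; the paper's exact tuning procedure for TunedRR may differ.

```python
import numpy as np


def ridge(X, y, lam):
    """Exact ridge regression solution (X^T X + lam I)^{-1} X^T y."""
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)


def sample_task(rng, n, d, sigma):
    """Hypothetical task sampler: w ~ N(0, I), x ~ N(0, I), label noise N(0, sigma^2)."""
    w = rng.normal(size=d)
    X = rng.normal(size=(n, d))
    y = X @ w + sigma * rng.normal(size=n)
    x_t = rng.normal(size=d)
    return X, y, x_t, w


rng = np.random.default_rng(0)
n, d, sigma_max, num_tasks = 20, 10, 5.0, 2000
const_lam = 1.0  # ConstRR-style: one penalty for every task

err_const, err_tuned = [], []
for _ in range(num_tasks):
    sigma = rng.uniform(0.0, sigma_max)  # sigma_tau ~ U(0, sigma_max)
    X, y, x_t, w = sample_task(rng, n, d, sigma)
    target = w @ x_t  # noiseless query target
    err_const.append((ridge(X, y, const_lam) @ x_t - target) ** 2)
    err_tuned.append((ridge(X, y, sigma ** 2) @ x_t - target) ** 2)  # TunedRR-style

print(f"ConstRR-style MSE: {np.mean(err_const):.3f}")
print(f"TunedRR-style MSE: {np.mean(err_tuned):.3f}")
```

Scoring against the noiseless target $\langle w, x_t\rangle$ drops the irreducible label noise, which is identical for both solvers. For the categorical setting of Figures 3 and 4, the uniform draw can be replaced with `sigma = rng.choice([1.0, 3.0])`.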

Theorems & Definitions (16)

  • Theorem 4.1
  • Theorem 4.2
  • Lemma 4.3
  • Lemma 4.4
  • Theorem 5.1
  • Theorem 5.2
  • proof
  • Lemma A.1
  • proof
  • proof
  • ...and 6 more