Learning Linear Attention in Polynomial Time

Morris Yau, Ekin Akyürek, Jiayuan Mao, Joshua B. Tenenbaum, Stefanie Jegelka, Jacob Andreas

TL;DR

The paper asks whether linear attention mechanisms in transformers can be learned efficiently from data. It recasts multi-head linear attention (MHLA) as a linear predictor in a fixed RKHS via a cubic feature map. This reduces learning to polynomial-time linear regression in an expanded feature space, after which the MHLA parameters are recovered by SVD using at most $d^2$ heads. A certifiable identifiability condition, based on the second-moment matrix $\Lambda_D$, guarantees that all empirical risk minimizers compute the same function. This yields robust generalization, including for a class of universal Turing machines with polynomially bounded computation histories. Experiments support the theory: extra heads can accelerate SGD optimization, and the identifiability certificate correlates with generalization on associative-memory and DFA-like tasks, bridging expressivity and learnability for transformers.
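
A minimal NumPy sketch of the reduction to a linear predictor. It assumes each head computes $V_h Z Z^\top Q_h z$, where $z$ is the last column of $Z$, with no length normalization; the paper's Definition 2.1 may scale differently. The names `mhla`, `cubic_features` and `coefficient_matrix` are illustrative, not the authors' code. Under this assumption, every output coordinate $y_a$ is a fixed linear function of the degree-3 features $(ZZ^\top)_{bc}\, z_e$, with coefficients $W_{a,(b,c,e)} = \sum_h (V_h)_{ab} (Q_h)_{ce}$.

```python
import numpy as np

def mhla(Z, V, Q):
    """Multi-head linear attention (assumed form): sum_h V_h Z Z^T Q_h z.

    Z: (d, n) input sequence; V, Q: (m, d, d) per-head value and query-key matrices.
    """
    z = Z[:, -1]                      # query token = last column (assumption)
    G = Z @ Z.T                       # (d, d) uncentered Gram matrix of the sequence
    return sum(V[h] @ G @ Q[h] @ z for h in range(V.shape[0]))

def cubic_features(Z):
    """Degree-3 feature map phi(Z)[b, c, e] = (Z Z^T)[b, c] * z[e], flattened to length d^3."""
    z = Z[:, -1]
    return np.einsum("bc,e->bce", Z @ Z.T, z).reshape(-1)

def coefficient_matrix(V, Q):
    """Linear-predictor weights W[a, (b, c, e)] = sum_h V_h[a, b] * Q_h[c, e], shape (d, d^3)."""
    d = V.shape[1]
    return np.einsum("hab,hce->abce", V, Q).reshape(d, d ** 3)

rng = np.random.default_rng(0)
d, n, m = 4, 7, 3
Z = rng.standard_normal((d, n))
V = rng.standard_normal((m, d, d))
Q = rng.standard_normal((m, d, d))

y_attention = mhla(Z, V, Q)
y_linear = coefficient_matrix(V, Q) @ cubic_features(Z)
print(np.allclose(y_attention, y_linear))  # True
```

Because `cubic_features` does not depend on the attention parameters, fitting $W$ is ordinary least squares on fixed features. Note that the $(b, c)$ and $(c, b)$ coordinates coincide, so a least-squares solution for $W$ is only determined up to that symmetry.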

Abstract

Previous research has explored the computational expressivity of Transformer models in simulating Boolean circuits or Turing machines. However, the learnability of these simulators from observational data has remained an open question. Our study addresses this gap by providing the first polynomial-time learnability results (specifically strong, agnostic PAC learning) for single-layer Transformers with linear attention. We show that linear attention may be viewed as a linear predictor in a suitably defined RKHS. As a consequence, the problem of learning any linear transformer may be converted into the problem of learning an ordinary linear predictor in an expanded feature space, and any such predictor may be converted back into a multi-head linear transformer. Moving to generalization, we show how to efficiently identify training datasets for which every empirical risk minimizer is equivalent (up to trivial symmetries) to the linear Transformer that generated the data, thereby guaranteeing the learned model will correctly generalize across all inputs. Finally, we provide examples of computations expressible via linear attention and therefore polynomial-time learnable, including associative memories, finite automata, and a class of Universal Turing Machines (UTMs) with polynomially bounded computation histories. We empirically validate our theoretical findings on three tasks: learning random linear attention networks, key-value associations, and learning to execute finite automata. Our findings bridge a critical gap between theoretical expressivity and learnability of Transformers, and show that flexible and general models of computation are efficiently learnable.
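
The step that converts a learned linear predictor back into a multi-head linear transformer can be sketched with an SVD. The sketch assumes the predictor is stored as a $d \times d^3$ coefficient matrix $W$ whose entry $W[a,(b,c,e)]$ multiplies the cubic feature $(ZZ^\top)_{bc}\, z_e$, and that head $h$ contributes $(V_h)_{ab}(Q_h)_{ce}$ to that entry. Reshaping $W$ into a $d^2 \times d^2$ matrix, with rows indexed by $(a,b)$ and columns by $(c,e)$, and keeping its nonzero singular directions gives at most $d^2$ heads that reproduce $W$. `heads_from_coefficients` is a hypothetical helper, not the paper's implementation.

```python
import numpy as np

def heads_from_coefficients(W, tol=1e-10):
    """Recover (V_h, Q_h) pairs with sum_h V_h[a, b] * Q_h[c, e] == W[a, (b, c, e)].

    W: (d, d^3) coefficient matrix of a linear predictor on cubic features.
    Returns V, Q of shape (r, d, d), where r = numerical rank <= d^2.
    """
    d = W.shape[0]
    M = W.reshape(d * d, d * d)              # rows indexed by (a, b), columns by (c, e)
    U, s, Vt = np.linalg.svd(M)
    keep = s > tol * s[0]                    # drop numerically zero singular values
    V_heads = (U[:, keep] * s[keep]).T.reshape(-1, d, d)  # value side absorbs sigma_k
    Q_heads = Vt[keep].reshape(-1, d, d)
    return V_heads, Q_heads

# Usage: build W from two random ground-truth heads, then recover heads from W alone.
rng = np.random.default_rng(1)
d, m_true = 3, 2
V_true = rng.standard_normal((m_true, d, d))
Q_true = rng.standard_normal((m_true, d, d))
W = np.einsum("hab,hce->abce", V_true, Q_true).reshape(d, d ** 3)

V_hat, Q_hat = heads_from_coefficients(W)
W_hat = np.einsum("hab,hce->abce", V_hat, Q_hat).reshape(d, d ** 3)
print(V_hat.shape[0], np.allclose(W, W_hat))  # 2 True
```

The recovered heads generally differ from `V_true` and `Q_true`: the SVD returns an orthogonalized factorization, and any head can trade a scalar between $V_h$ and $Q_h$. They still induce the same coefficient matrix and therefore compute the same function.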

Paper Structure

This paper contains 39 sections, 26 theorems, 72 equations, 4 figures, 8 algorithms.

Key Result

Theorem 2.2

Let $D$ be a dataset $D = \{Z_i, y_i\}_{i \in [N]}$ drawn i.i.d. from a distribution $\mathcal{D}$, where each $Z_i \in \mathbb{R}^{d \times n_i}$ and $y_i \in \mathbb{R}^d$. The embedding dimension $d$ is fixed across the dataset, whereas $n_i$ can differ between datapoints. Let $n_{max} = \max_{i \in [N]} n_i$. … with sample complexity $N = O\left(\frac{1}{\epsilon}\left(d^4 + \log(\delta^{-1})\right)\right)$.
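
A heuristic way to see where the $d^4$ term could come from is a parameter count, not the paper's proof. Assume each head computes $V_h Z Z^\top Q_h z$, with $z$ the last column of $Z$, and that the learner regresses each of the $d$ output coordinates on the $d^3$ cubic features $(ZZ^\top)_{bc}\, z_e$. The linear predictor then has $d^4$ coefficients:

```latex
y_a = \sum_{b,c,e=1}^{d} W_{a,(b,c,e)}\,(Z Z^\top)_{bc}\, z_e,
\qquad
W_{a,(b,c,e)} = \sum_{h=1}^{m} (V_h)_{ab}\,(Q_h)_{ce},
\qquad
\#\{W_{a,(b,c,e)}\} = d \cdot d^3 = d^4 .
```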

Figures (4)

  • Figure 1: Performance comparison of multi-head, multi-layer linear attention models and the original Transformer model (denoted as full). We trained using SGD on synthetic data generated from a single-layer linear attention model for varying training set sizes ($N$), input dimensions ($d$), numbers of heads ($m$), and numbers of layers ($n$). Results show that multi-head architectures converge faster across input dimensions and match the performance of our polynomial-time algorithm (the convex algorithm). Increasing the number of layers or adding multilayer perceptrons (MLPs) and layer normalization did not yield consistent improvements. Shading indicates the standard error over three runs.
  • Figure 2: Impact of data distribution on associative lookup performance. We generated training data for an associative lookup task (Bietti et al., 2023; Cabannes et al., 2024) using mixtures of two distributions: (1) Gaussian key and value vectors, and (2) random unitary key and value vectors. By adjusting the mixture probability, we can manipulate the certificate value (the minimum eigenvalue of the data covariance matrix), because unitary key-value vectors give rank-deficient certificates. (a) Polynomial-time algorithm: as the minimum eigenvalue increases, the algorithm converges more closely to the true parameters. (b) SGD: for certifiably identifiable data, SGD learns parameters that are equivalent to the ground-truth parameters in $p$ feature space; for unidentifiable data, the learned and ground-truth parameters are far apart in $p$ feature space and therefore compute different functions.
  • Figure 3: Performance comparison of multi-head, multi-layer linear attention models and the original Transformer model (denoted as full). We trained using SGD on synthetic data generated from a single-layer linear attention model for varying training set sizes ($N$), input dimensions ($d$), numbers of heads ($m$), and numbers of layers ($n$). We report the mean squared error of the predictions with respect to the number of training epochs. Results show that multi-head architectures converge faster across input dimensions and match the performance of our polynomial-time algorithm (the convex algorithm). Increasing the number of layers or adding multilayer perceptrons (MLPs) and layer normalization did not yield consistent improvements. Shading indicates the standard error over three runs.
  • Figure 4: Data requirement for universal DFA simulation. We train a fixed-size Transformer (4 layers, 16 heads, hidden dimension 2048) to simulate a DFA given a transition table and an input word. The vertical axis shows the number of tokens (word length $L$ times the number of examples $Q$) required to reach 99% next-token accuracy.
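
The DFA simulation task in Figure 4 can be illustrated with a small data generator. The sketch assumes a DFA is given by a transition table `delta[state][symbol]` and that the target at each position is the state reached after reading that prefix of the word. How the authors serialize the table and word into Transformer tokens is not described here, so `random_dfa` and `run_dfa` are hypothetical helpers.

```python
import random

def random_dfa(num_states, num_symbols, seed=0):
    """Sample a transition table: delta[state][symbol] -> next state."""
    rng = random.Random(seed)
    return [[rng.randrange(num_states) for _ in range(num_symbols)]
            for _ in range(num_states)]

def run_dfa(delta, word, start=0):
    """Return the state reached after each prefix of `word` (the per-position target)."""
    states, q = [], start
    for symbol in word:
        q = delta[q][symbol]
        states.append(q)
    return states

# A random 5-state DFA over a 3-symbol alphabet and one input word of length L = 10.
delta = random_dfa(num_states=5, num_symbols=3, seed=42)
rng = random.Random(7)
word = [rng.randrange(3) for _ in range(10)]
targets = run_dfa(delta, word)

# Sanity check on a hand-written parity automaton: the state flips on symbol 1.
parity = [[0, 1],
          [1, 0]]
print(run_dfa(parity, [1, 0, 1, 1]))  # [1, 1, 0, 1]
```

Under this assumed setup, drawing $Q$ (table, word) pairs with words of length $L$ yields $L \times Q$ target states, which matches the token count on the vertical axis of Figure 4.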

Theorems & Definitions (47)

  • Definition 2.1: Multi-Head Linear Attention
  • Theorem 2.2: Learnability of Linear Attention
  • Lemma 2.3: Certificate of Identifiability (Informal)
  • Corollary 2.4
  • Lemma 3.0: Learning UTM from Certifiably Identifiable Data
  • Definition A.1: Identifiability
  • Definition A.2: Realizability
  • Lemma A.2: Certificate of Identifiability
  • Corollary A.3
  • Lemma A.3: Independent input noise yields identifiability
  • ...and 37 more
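
The certificate of identifiability can be sketched numerically. The sketch assumes $\Lambda_D$ is the empirical second-moment matrix $\frac{1}{N}\sum_i \phi(Z_i)\phi(Z_i)^\top$ of a symmetry-reduced cubic feature map $\phi(Z) = \big((ZZ^\top)_{bc}\, z_e\big)_{b \le c,\, e}$, where $z$ is the last column of $Z$. It also assumes the certificate is the minimum eigenvalue of $\Lambda_D$, following the Figure 2 caption. The feature map and helper names are illustrative assumptions, and the paper's exact $\Lambda_D$ may differ. Only pairs with $b \le c$ are kept because $(ZZ^\top)_{bc} = (ZZ^\top)_{cb}$; keeping both would make the second-moment matrix singular for every dataset.

```python
import numpy as np

def reduced_cubic_features(Z):
    """phi(Z) = ((Z Z^T)[b, c] * z[e]) for b <= c and all e; z = last column of Z."""
    d = Z.shape[0]
    G, z = Z @ Z.T, Z[:, -1]
    b, c = np.triu_indices(d)                 # upper triangle including the diagonal
    return np.outer(G[b, c], z).reshape(-1)   # length d(d+1)/2 * d

def certificate(dataset):
    """Minimum eigenvalue of the empirical second-moment matrix of the features."""
    Phi = np.stack([reduced_cubic_features(Z) for Z in dataset])  # (N, p)
    Lambda = Phi.T @ Phi / len(dataset)
    return np.linalg.eigvalsh(Lambda)[0]      # eigvalsh returns ascending eigenvalues

rng = np.random.default_rng(0)
d, n, N = 3, 5, 2000
gaussian = [rng.standard_normal((d, n)) for _ in range(N)]
# Rank-deficient data: the last embedding coordinate is always zero,
# so every feature that touches it vanishes identically.
degenerate = [np.vstack([rng.standard_normal((d - 1, n)), np.zeros((1, n))])
              for _ in range(N)]

print(certificate(gaussian) > 1e-3)          # True
print(abs(certificate(degenerate)) < 1e-8)   # True
```

A positive minimum eigenvalue means the features span every direction, so the data pin down the predictor's coefficients on the reduced features. A zero eigenvalue leaves some directions unconstrained, and different minimizers can compute different functions.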