Learning Linear Attention in Polynomial Time
Morris Yau, Ekin Akyürek, Jiayuan Mao, Joshua B. Tenenbaum, Stefanie Jegelka, Jacob Andreas
TL;DR
The paper asks whether linear attention mechanisms in Transformers can be learned efficiently from data. It recasts multi-head linear attention (MHLA) as a linear predictor in a fixed RKHS via a cubic feature map, reduces learning to polynomial-time linear regression in the expanded feature space, and then recovers an MHLA with at most $d^2$ heads from the regression solution by SVD. When a certifiable identifiability condition on the second-moment matrix $\Lambda_D$ holds, all empirical risk minimizers compute the same function, which yields generalization guarantees for tasks that include universal Turing machines with bounded computation histories. Experiments support the theory: extra heads can accelerate SGD optimization, and the identifiability certificate correlates with generalization on associative-memory and DFA-like tasks. Together, these results connect expressivity and learnability for Transformers.
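The reduction summarized above can be sketched numerically. The NumPy example below is a minimal illustration under stated assumptions, not the authors' implementation. It assumes the last-token parameterization $\hat{y} = \sum_h V_h Z Z^\top Q_h z_n$, with $Z \in \mathbb{R}^{d \times n}$ and $z_n$ its final column. The function names (`mhla`, `cubic_features`) and the synthetic Gaussian data are hypothetical.

```python
import numpy as np

def mhla(Z, Vs, Qs):
    """Last-token output of multi-head linear attention (assumed form)."""
    z = Z[:, -1]
    S = Z @ Z.T
    return sum(V @ S @ Q @ z for V, Q in zip(Vs, Qs))

def cubic_features(Z):
    """Cubic feature map: entries (Z Z^T)[b, c] * z_n[e], flattened in (b, c, e) order."""
    z = Z[:, -1]
    return np.einsum("bc,e->bce", Z @ Z.T, z).ravel()

rng = np.random.default_rng(0)
d, n, H, N = 3, 5, 2, 400

# Hypothetical ground-truth heads and synthetic training inputs.
Vs = rng.normal(size=(H, d, d))
Qs = rng.normal(size=(H, d, d))
Zs = rng.normal(size=(N, d, n))

X = np.stack([cubic_features(Z) for Z in Zs])   # shape (N, d^3)
Y = np.stack([mhla(Z, Vs, Qs) for Z in Zs])     # shape (N, d)

# Step 1: ordinary least squares in the expanded feature space.
# Z Z^T is symmetric, so X is rank deficient; lstsq returns the minimum-norm solution.
W, *_ = np.linalg.lstsq(X, Y, rcond=1e-10)      # shape (d^3, d)

# Step 2: view the weights as a (d^2 x d^2) matrix M[(a, b), (c, e)] and factor it by SVD.
M = W.T.reshape(d, d, d, d).reshape(d * d, d * d)
U, s, Vt = np.linalg.svd(M)
Vs_hat = [(U[:, h] * s[h]).reshape(d, d) for h in range(d * d)]
Qs_hat = [Vt[h].reshape(d, d) for h in range(d * d)]

# The recovered heads (at most d^2 of them) should reproduce the target on a fresh input.
Z_test = rng.normal(size=(d, n))
err = np.max(np.abs(mhla(Z_test, Vs, Qs) - mhla(Z_test, Vs_hat, Qs_hat)))
print(f"max abs error on a held-out input: {err:.2e}")
```

In this sketch the recovered heads need not match the generating heads one by one; only the function they compute is compared. Heads whose singular value is numerically zero could be dropped without changing the output.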
Abstract
Previous research has explored the computational expressivity of Transformer models in simulating Boolean circuits or Turing machines. However, the learnability of these simulators from observational data has remained an open question. Our study addresses this gap by providing the first polynomial-time learnability results (specifically strong, agnostic PAC learning) for single-layer Transformers with linear attention. We show that linear attention may be viewed as a linear predictor in a suitably defined RKHS. As a consequence, the problem of learning any linear Transformer may be converted into the problem of learning an ordinary linear predictor in an expanded feature space, and any such predictor may be converted back into a multiheaded linear Transformer. Moving to generalization, we show how to efficiently identify training datasets for which every empirical risk minimizer is equivalent (up to trivial symmetries) to the linear Transformer that generated the data, thereby guaranteeing the learned model will correctly generalize across all inputs. Finally, we provide examples of computations expressible via linear attention and therefore polynomial-time learnable, including associative memories, finite automata, and a class of universal Turing machines (UTMs) with polynomially bounded computation histories. We empirically validate our theoretical findings on three tasks: learning random linear attention networks, learning key-value associations, and learning to execute finite automata. Our findings bridge a critical gap between theoretical expressivity and learnability of Transformers, and show that flexible and general models of computation are efficiently learnable.
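The idea of checking a training set for identifiability can be illustrated with a small NumPy sketch. It assumes, for illustration only, that the certificate amounts to a strictly positive smallest eigenvalue of the empirical second-moment matrix of deduplicated cubic features $(ZZ^\top)_{bc}\, z_{n,e}$ with $b \le c$. The exact definition of $\Lambda_D$ used by the authors is not given in this section, and the names `dedup_features` and `min_eig_second_moment` are hypothetical.

```python
import numpy as np

def dedup_features(Z):
    """Distinct cubic features (Z Z^T)[b, c] * z_n[e] with b <= c (assumed feature map)."""
    z = Z[:, -1]
    S = Z @ Z.T
    upper = np.triu_indices(Z.shape[0])   # keep one copy of each symmetric entry
    return np.outer(S[upper], z).ravel()

def min_eig_second_moment(Zs):
    """Smallest eigenvalue of the empirical second-moment matrix of the features."""
    X = np.stack([dedup_features(Z) for Z in Zs])
    Lam = X.T @ X / len(Zs)
    return np.linalg.eigvalsh(Lam)[0]     # eigvalsh returns eigenvalues in ascending order

rng = np.random.default_rng(0)
d, n = 3, 5
rich = rng.normal(size=(200, d, n))       # many diverse examples
poor = rich[:5]                           # fewer examples than features

lam_rich = min_eig_second_moment(rich)
lam_poor = min_eig_second_moment(poor)
print(f"rich dataset: {lam_rich:.3e}, poor dataset: {lam_poor:.3e}")
```

Under this assumed form, a strictly positive smallest eigenvalue means the regression weights on the deduplicated features are uniquely determined by the data. A value at numerical zero, as for the five-example dataset, means several distinct predictors fit the training set equally well.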
