Provably learning a multi-head attention layer
Sitan Chen, Yuanzhi Li
TL;DR
This work establishes the first computational-learnability guarantees for nonlinear multi-head attention under a benign Boolean input model. It introduces a six-phase algorithm that (i) crudely estimates the sum of the projection matrices, (ii) sculpts an affine hull that captures candidate attention heads, (iii) uses a minimum-norm proxy to refine this hull, (iv) recovers the span of the attention matrices, (v) performs LP-based certification, and (vi) applies least-squares refinement to obtain accurate attention and projection matrices. The approach hinges on new uses of tail and anti-concentration bounds for the input distribution and on a geometric view of the parameter space, in contrast to the moment-based methods traditionally used for feed-forward networks. The paper also proves computational lower bounds (in the statistical query model and under cryptographic assumptions) showing that, in the worst case, exponential dependence of the running time on the number of heads $m$ is unavoidable. Together, these results illuminate both the capabilities and the limits of provable learning for transformers and point to promising directions: removing assumptions, understanding gradient-based optimization, and scaling to deeper architectures.
Abstract
The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models. Given a sequence length $k$, attention matrices $\mathbf{\Theta}_1,\ldots,\mathbf{\Theta}_m\in\mathbb{R}^{d\times d}$, and projection matrices $\mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$, the corresponding multi-head attention layer $F: \mathbb{R}^{k\times d}\to \mathbb{R}^{k\times d}$ transforms length-$k$ sequences of $d$-dimensional tokens $\mathbf{X}\in\mathbb{R}^{k\times d}$ via $F(\mathbf{X}) \triangleq \sum^m_{i=1} \mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$. In this work, we initiate the study of provably learning a multi-head attention layer from random examples and give the first nontrivial upper and lower bounds for this problem:
- Provided $\{\mathbf{W}_i, \mathbf{\Theta}_i\}$ satisfy certain non-degeneracy conditions, we give a $(dk)^{O(m^3)}$-time algorithm that learns $F$ to small error given random labeled examples drawn uniformly from $\{\pm 1\}^{k\times d}$.
- We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.
We focus on Boolean $\mathbf{X}$ to mimic the discrete nature of tokens in large language models, though our techniques naturally extend to standard continuous settings, e.g. Gaussian. Our algorithm, which is centered around using examples to sculpt a convex body containing the unknown parameters, is a significant departure from existing provable algorithms for learning feedforward networks, which predominantly exploit algebraic and rotation invariance properties of the Gaussian distribution. In contrast, our analysis is more flexible as it primarily relies on various upper and lower tail bounds for the input distribution and "slices" thereof.
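To make the definition of $F$ concrete, the following is a minimal NumPy sketch of the forward map and of the Boolean example distribution. It is an illustration only, not the learning algorithm from the paper. The function names (`softmax_rows`, `multi_head_attention`, `sample_boolean_inputs`), the sizes `k`, `d`, `m`, and the random choice of parameters are all hypothetical; the sketch also assumes the softmax is applied to each row of $\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top$ separately, which the abstract does not state explicitly.

```python
import numpy as np

def softmax_rows(A):
    # Assumption: softmax acts on each row independently.
    # Subtracting the row maximum avoids overflow without changing the result.
    Z = np.exp(A - A.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def multi_head_attention(X, Thetas, Ws):
    """Compute F(X) = sum_i softmax(X Theta_i X^T) X W_i for X of shape (k, d)."""
    out = np.zeros_like(X, dtype=float)
    for Theta, W in zip(Thetas, Ws):
        scores = X @ Theta @ X.T              # (k, k) attention scores for this head
        out += softmax_rows(scores) @ X @ W   # (k, d) contribution of this head
    return out

def sample_boolean_inputs(rng, n, k, d):
    # n sequences drawn uniformly from {+1, -1}^{k x d}.
    return rng.choice([-1.0, 1.0], size=(n, k, d))

# Hypothetical sizes and randomly chosen parameters, for illustration only.
rng = np.random.default_rng(0)
k, d, m = 4, 6, 2
Thetas = rng.normal(size=(m, d, d)) / d   # arbitrary scaling, not from the paper
Ws = rng.normal(size=(m, d, d))

X = sample_boolean_inputs(rng, 1, k, d)[0]
Y = multi_head_attention(X, Thetas, Ws)   # one labeled example (X, Y)
```

A labeled example in the learning problem is a pair $(\mathbf{X}, F(\mathbf{X}))$ of this form. As a sanity check on the formula, setting every $\mathbf{\Theta}_i$ to zero makes each softmax uniform, so every row of the output equals the average token of $\mathbf{X}$ multiplied by $\sum_i \mathbf{W}_i$.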
