
Provably learning a multi-head attention layer

Sitan Chen, Yuanzhi Li

TL;DR

This work establishes the first computational-learnability guarantees for nonlinear multi-head attention under a benign Boolean input model. It introduces a six-phase algorithm that (i) crudely estimates the sum of the projection matrices, (ii) sculpts an affine hull that captures candidate attention heads, (iii) uses a minimum-norm proxy to refine this hull, (iv) recovers the span of the attention matrices, (v) performs LP-based certification, and (vi) applies least-squares refinement to obtain accurate attention and projection matrices. The approach hinges on new uses of tail and anti-concentration bounds for the input distribution and on a geometric view of the parameter space, in contrast to the moment-based methods traditionally used for feed-forward networks. The paper also proves computational lower bounds exponential in $m$ (in the statistical query model and under cryptographic assumptions), indicating that exponential dependence on the number of heads is unavoidable in the worst case. Together, these results clarify both the capabilities and the limits of provable learning for transformers, and point to promising directions: removing assumptions, understanding gradient-based optimization, and scaling to deeper architectures.
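The least-squares refinement step rests on a simple observation: writing $A_i(\mathbf{X}) = \mathrm{softmax}(\mathbf{X}\boldsymbol{\Theta}_i\mathbf{X}^\top)$, the layer output $\sum_{i=1}^m A_i(\mathbf{X})\mathbf{X}\mathbf{W}_i$ is linear in the projection matrices once the attention matrices are fixed, so the $\mathbf{W}_i$ can be fit by ordinary least squares. The NumPy sketch below illustrates only that observation, under idealized assumptions: the $\boldsymbol{\Theta}_i$ are known exactly, labels are noiseless, the softmax is applied row-wise, and the stacked design matrix has full column rank. It is not the paper's refinement procedure, and the function names (`head_features`, `fit_projections`) are hypothetical.

```python
import numpy as np


def row_softmax(scores):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=1, keepdims=True)


def head_features(X, thetas):
    """[A_1 X, ..., A_m X] with A_i = softmax(X Theta_i X^T); shape (k, m*d)."""
    return np.hstack([row_softmax(X @ t @ X.T) @ X for t in thetas])


def fit_projections(xs, ys, thetas):
    """Least-squares estimate of W_1..W_m from examples, given fixed Theta_i."""
    d = xs.shape[2]
    design = np.vstack([head_features(x, thetas) for x in xs])  # (n*k, m*d)
    targets = ys.reshape(-1, d)                                  # (n*k, d), same row order
    stacked, *_ = np.linalg.lstsq(design, targets, rcond=None)   # (m*d, d)
    return [stacked[i * d:(i + 1) * d] for i in range(len(thetas))]


# Synthetic check: labels from F(X) = [A_1 X, ..., A_m X] @ [W_1; ...; W_m].
rng = np.random.default_rng(1)
k, d, m, n = 3, 4, 2, 200
thetas = [rng.standard_normal((d, d)) / d for _ in range(m)]  # illustrative scale
true_ws = [rng.standard_normal((d, d)) for _ in range(m)]
xs = rng.choice([-1.0, 1.0], size=(n, k, d))  # uniform over {+1, -1}^{k x d}
ys = np.stack([head_features(x, thetas) @ np.vstack(true_ws) for x in xs])
est_ws = fit_projections(xs, ys, thetas)
err = max(np.abs(e - w).max() for e, w in zip(est_ws, true_ws))
print(f"max entrywise error: {err:.2e}")
```

Stacking the per-head features side by side turns the sum over heads into a single matrix product, which is why one `lstsq` call recovers all $m$ projection matrices at once.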

Abstract

The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models. Given a sequence length $k$, attention matrices $\boldsymbol{\Theta}_1,\ldots,\boldsymbol{\Theta}_m\in\mathbb{R}^{d\times d}$, and projection matrices $\mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$, the corresponding multi-head attention layer $F: \mathbb{R}^{k\times d}\to \mathbb{R}^{k\times d}$ transforms length-$k$ sequences of $d$-dimensional tokens $\mathbf{X}\in\mathbb{R}^{k\times d}$ via $F(\mathbf{X}) \triangleq \sum^m_{i=1} \mathrm{softmax}(\mathbf{X}\boldsymbol{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$. In this work, we initiate the study of provably learning a multi-head attention layer from random examples and give the first nontrivial upper and lower bounds for this problem:

- Provided $\{\mathbf{W}_i, \boldsymbol{\Theta}_i\}$ satisfy certain non-degeneracy conditions, we give a $(dk)^{O(m^3)}$-time algorithm that learns $F$ to small error given random labeled examples drawn uniformly from $\{\pm 1\}^{k\times d}$.
- We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.

We focus on Boolean $\mathbf{X}$ to mimic the discrete nature of tokens in large language models, though our techniques naturally extend to standard continuous settings, e.g. Gaussian. Our algorithm, which is centered around using examples to sculpt a convex body containing the unknown parameters, is a significant departure from existing provable algorithms for learning feed-forward networks, which predominantly exploit algebraic and rotation invariance properties of the Gaussian distribution. In contrast, our analysis is more flexible as it primarily relies on various upper and lower tail bounds for the input distribution and "slices" thereof.
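To make the definition of $F$ concrete, here is a minimal NumPy sketch of the forward map and of drawing labeled examples uniformly from $\{\pm 1\}^{k\times d}$. It assumes the softmax is applied row-wise (each query token's scores are normalized over the $k$ key tokens), the standard convention; the function names and the random parameter scales are illustrative choices, not taken from the paper.

```python
import numpy as np


def row_softmax(scores):
    """Row-wise softmax; subtracting each row's max keeps exp() from overflowing."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=1, keepdims=True)


def multi_head_attention(X, thetas, ws):
    """F(X) = sum_i softmax(X Theta_i X^T) X W_i for a (k, d) input X."""
    out = np.zeros(X.shape)
    for theta, w in zip(thetas, ws):
        attn = row_softmax(X @ theta @ X.T)  # (k, k) attention pattern of head i
        out += attn @ X @ w                  # (k, d) contribution of head i
    return out


def sample_examples(n, thetas, ws, k, d, rng):
    """Draw n inputs uniformly from {+1, -1}^{k x d} and label them with F."""
    xs = rng.choice([-1.0, 1.0], size=(n, k, d))
    ys = np.stack([multi_head_attention(x, thetas, ws) for x in xs])
    return xs, ys


rng = np.random.default_rng(0)
k, d, m = 3, 4, 2
thetas = [rng.standard_normal((d, d)) / d for _ in range(m)]  # illustrative scale
ws = [rng.standard_normal((d, d)) for _ in range(m)]
xs, ys = sample_examples(5, thetas, ws, k, d, rng)
print(xs.shape, ys.shape)  # (5, 3, 4) (5, 3, 4)
```

As a sanity check, setting $\boldsymbol{\Theta}_i = \mathbf{0}$ makes a head attend uniformly, so every output row of that head equals $\frac{1}{k}\mathbf{1}^\top\mathbf{X}\mathbf{W}_i$.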

Paper Structure (89 sections, 85 theorems, 551 equations)

Key Result

Theorem 1.2

Let $F: \{\pm 1\}^{k\times d}\to \mathbb{R}^{k\times d}$ be a multi-head attention layer whose attention and projection matrices $\{(\boldsymbol{\Theta}_i,\mathbf{W}_i)\}^m_{i=1}$ are non-degenerate in the sense of Section \ref{sec:assume}. Then given at least $N = (kd)^{\Theta(m)} + \mathrm{poly}(m,k,d)$…

Theorems & Definitions (172)

  • Remark 1.1
  • Theorem 1.2
  • Theorem 1.3: Informal, see Theorem \ref{thm:main_lbd}
  • Theorem 3.1: Hanson-Wright
  • Lemma 3.2
  • proof
  • Theorem 3.3: Theorem 1.1 in dvorak2022probability
  • Theorem 3.4: Corollary 20.1.4 from nagaev2002lower
  • Lemma 3.5: Berry-Esseen
  • Lemma 3.6
  • ...and 162 more