On the Benefits of Rank in Attention Layers

Noah Amsel, Gilad Yehudai, Joan Bruna

TL;DR

This work analyzes the expressive power of attention as a function of the query/key rank $r$ and the number of heads $H$. It proves a rank separation for a nearest-neighbor target: a single full-rank head can approximate it arbitrarily well, whereas low-rank heads need $H$ to grow exponentially in the embedding dimension $d$ (or on $(d/r)^{1/()}$-type scales) to reach comparable accuracy. The authors further show that depth can mitigate this weakness for short contexts, allowing polynomially many rank-1 heads to approximate the target, although the same approach may not extend to long contexts. Experiments with off-the-shelf transformers corroborate the theory, suggesting that the standard $H = d/r$ scaling may understate the expressive power of full-rank attention and that low-rank attention can be significantly weaker under practical parameter budgets. Overall, the paper calls for rethinking hyperparameter choices in transformers and highlights depth as a potential remedy for some low-rank limitations, especially in shorter contexts.
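
A minimal NumPy sketch of the full-rank side of the separation, under assumptions of our own: unit-norm inputs, a value map equal to the identity, and a query/key product ${\bm{K}}{\bm{Q}}^\top = c{\bm{I}}$ (Figure 2 reports learned products close to $-c{\bm{I}}$ for the farthest-neighbor variant). For unit-norm points, the nearest neighbor of a query $y$ is the point with the largest inner product, so a large inverse temperature $c$ drives the softmax toward a one-hot selection of that point. The function names are illustrative, not taken from the paper.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability before exponentiating.
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def full_rank_head(X, y, c=1000.0):
    """One attention head with K Q^T = c * I and an identity value map.

    X: (N, d) array of candidate points (rows); y: (d,) query.
    Returns the attention-weighted average of the rows of X.
    """
    weights = softmax(c * (X @ y))  # score_i = c * <x_i, y>
    return weights @ X

def nearest_neighbor(X, y):
    # For unit-norm rows, ||x_i - y||^2 = ||y||^2 + 1 - 2 <x_i, y>,
    # so minimizing the distance is the same as maximizing <x_i, y>.
    return X[np.argmax(X @ y)]

X = np.eye(4)                            # four orthonormal points in R^4
y = np.array([0.1, 0.9, 0.2, 0.0])
y = y / np.linalg.norm(y)                # unit-norm query, closest to e_2
print(np.round(full_rank_head(X, y), 6))  # approximately [0, 1, 0, 0]
print(nearest_neighbor(X, y))
```

Lowering `c` blurs the selection into a soft average of nearby points, which is why the approximation is only arbitrarily good in the large-$c$ limit.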

Abstract

Attention-based mechanisms are widely used in machine learning, most prominently in transformers. However, hyperparameters such as the rank of the attention matrices and the number of heads are scaled nearly the same way in all realizations of this architecture, without theoretical justification. In this work we show that there are dramatic trade-offs between the rank and number of heads of the attention mechanism. Specifically, we present a simple and natural target function that can be represented using a single full-rank attention head for any context length, but that cannot be approximated by low-rank attention unless the number of heads is exponential in the embedding dimension, even for short context lengths. Moreover, we prove that, for short context lengths, adding depth allows the target to be approximated by low-rank attention. For long contexts, we conjecture that full-rank attention is necessary. Finally, we present experiments with off-the-shelf transformers that validate our theoretical findings.
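
To make the rank/heads trade-off concrete, here is a toy NumPy experiment under assumptions of our own: $d = 16$, $N = 8$ random unit-norm points, and the standard split into $H = d/r$ rank-1 heads, where head $h$ uses ${\bm{Q}}_h = {\bm{K}}_h = \sqrt{c}\,{\bm{e}}_h$ and the head outputs are averaged. This is only one natural low-rank configuration, not the optimum over all low-rank heads that the paper's lower bound covers, so it illustrates the gap rather than proving it.

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def head(X, y, Q, K):
    # Bilinear score <K^T x_i, Q^T y>: the head only sees r-dimensional projections.
    weights = softmax((X @ K) @ (Q.T @ y))
    return weights @ X  # identity value map

def random_unit(rng, n, d):
    Z = rng.standard_normal((n, d))
    return Z / np.linalg.norm(Z, axis=1, keepdims=True)

def mean_errors(d=16, N=8, c=1000.0, trials=200, seed=0):
    rng = np.random.default_rng(seed)
    I = np.eye(d)
    full_Q = np.sqrt(c) * I                               # one rank-d head: K Q^T = c I
    low_Qs = [np.sqrt(c) * I[:, [h]] for h in range(d)]   # H = d / r rank-1 heads (r = 1)
    full_err = low_err = 0.0
    for _ in range(trials):
        X = random_unit(rng, N, d)
        y = random_unit(rng, 1, d)[0]
        target = X[np.argmax(X @ y)]                      # nearest neighbor of y
        full_out = head(X, y, full_Q, full_Q)
        low_out = np.mean([head(X, y, Q, Q) for Q in low_Qs], axis=0)
        full_err += np.linalg.norm(full_out - target) / trials
        low_err += np.linalg.norm(low_out - target) / trials
    return full_err, low_err

full_err, low_err = mean_errors()
print(f"full-rank error: {full_err:.3f}, rank-1 heads error: {low_err:.3f}")
```

Both models here have the same query/key parameter count ($2dHr = 2d^2$), so the comparison is at a matched budget in that sense.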

Paper Structure (42 sections, 31 theorems, 198 equations, 5 figures, 1 table)

Key Result

Theorem 2 (Low-Rank Approximation Lower Bounds, Equivariant Case)

There exist universal constants $c, c', C$, and $C'$ such that if either of the following sets of assumptions holds: Then, for any choice of $H$ rank-$r$ generalized attention heads $\phi_h : \mathbb{R}^{r \times 2} \to \Delta^1$, ${\bm{V}}_h \in \mathbb{R}^{d \times d}$, ${\bm{K}}_h \in \mathbb{R}^{d \times r}$, the error of approximating the nearest neighbor function is bounded as follows where $f$ i
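
The statement restricts each head to see its input only through an $r$-dimensional projection. The sketch below is one possible reading of the shapes in Theorem 2, not the paper's definition: it assumes two candidate points stored as the columns of ${\bm{X}} \in \mathbb{R}^{d \times 2}$, that head $h$ outputs ${\bm{V}}_h {\bm{X}} \phi_h({\bm{K}}_h^\top {\bm{X}})$, and that the model sums the $H$ head outputs. `example_phi` is a hypothetical choice of $\phi_h$; any map into the simplex $\Delta^1$ would do. The check at the end shows that changing ${\bm{X}}$ only in directions orthogonal to the columns of ${\bm{K}}_h$ cannot change that head's weights.

```python
import numpy as np

def generalized_attention(X, heads):
    """Sum of H generalized rank-r heads (hypothetical reading of Theorem 2).

    X: (d, 2) array whose columns are the two candidate points.
    heads: list of (phi, V, K) with phi mapping an (r, 2) array to weights in Delta^1,
           V a (d, d) value matrix, and K a (d, r) projection.
    """
    out = np.zeros(X.shape[0])
    for phi, V, K in heads:
        w = phi(K.T @ X)                  # head h sees only the r x 2 matrix K_h^T X
        if w.shape != (2,) or np.any(w < 0) or not np.isclose(w.sum(), 1.0):
            raise ValueError("phi must return a point of the simplex Delta^1")
        out += V @ (X @ w)                # V_h applied to a convex combination of the points
    return out

def example_phi(Z):
    # Hypothetical phi: a softmax over the column sums of the projected points.
    s = Z.sum(axis=0)
    e = np.exp(s - s.max())
    return e / e.sum()

d, r = 6, 2
K = np.eye(d)[:, :r]                      # projects onto the first r coordinates
X = np.arange(12, dtype=float).reshape(d, 2) / 10.0
X_perturbed = X.copy()
X_perturbed[r:, :] += 5.0                 # change only coordinates outside span(K)
heads = [(example_phi, np.eye(d), K)]

w = example_phi(K.T @ X)
w_perturbed = example_phi(K.T @ X_perturbed)
print(np.allclose(w, w_perturbed))        # the head's weights cannot see the change
print(generalized_attention(X, heads).shape)
```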

Figures (5)

  • Figure 1: Standard transformers trained on the farthest neighbor function. The dimension is $d = 64$ and the number of input points is $N = 16$. Lines show the best of five runs (except for $L=3, \text{params} = d^3, r \in \{16, 32\}$, which show the best of eight). Across different numbers of layers and heads, high-rank models significantly outperform low-rank models with the same number of parameters.
  • Figure 2: Properties of learned ${\bm{K}} {\bm{Q}}^\top$ matrices for full-rank models with one layer. Boxplots show the distribution over heads from five runs, each on a model with between 1 and 64 full-rank heads. The left panel plots the Frobenius angle with the identity, $\arccos\left(\left\langle {\bm{K}} {\bm{Q}}^\top, {\bm{I}} \right\rangle_{\mathsf F} / (\|{\bm{K}} {\bm{Q}}^\top\|_{\mathsf F} \|{\bm{I}}\|_{\mathsf F})\right)$. The results show that ${\bm{K}}{\bm{Q}}^\top$ nearly equals $-c{\bm{I}}$ for some $c > 1000$ in all cases.
  • Figure 3: Standard transformers with positional encodings ($d=64$, $N=16$). Lines show the best of five runs; shaded regions show the range over five runs. Positional encodings help when they are concatenated to the inputs and there are multiple layers (cf. \ref{thm:majority_positional}); otherwise, they do not help.
  • Figure 4: Effect of the number of points ($N$) on the difficulty of learning the farthest neighbor function. Full-rank attention learns an accurate representation across many values of $N$, but the performance of low-rank attention degrades as $N$ grows. The dimension is $d = 64$. All models have two layers with $H = d^2 / r$ heads each. Lines show the best of five runs; shaded regions show the range over five runs.
  • Figure 5: Approximation to $u(\cdot)$ of \ref{eq:u_t} for several dimensions, using a degree-50 ultraspherical expansion. Heads with $\angle({\bm{q}}, {\bm{k}}) = \theta$ are equivalent to those with angles $\theta \pm \pi$ up to a sign flip. For large dimensions, the distribution over $\angle({\bm{q}}, {\bm{k}})$ induced by $u$ approaches a Gaussian with mean $0$.
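
The Frobenius angle plotted in Figure 2 is easy to compute directly. A minimal NumPy sketch (the function name is ours): for ${\bm{K}}{\bm{Q}}^\top = -c{\bm{I}}$ with any $c > 0$ the angle is exactly $\pi$, and for a positive multiple of ${\bm{I}}$ it is $0$, so angles near $\pi$ indicate that the learned product points in the direction of $-{\bm{I}}$ regardless of its scale.

```python
import numpy as np

def frobenius_angle(A, B):
    """Angle between matrices A and B under the Frobenius inner product."""
    inner = np.sum(A * B)                      # <A, B>_F = trace(A^T B)
    cos = inner / (np.linalg.norm(A, "fro") * np.linalg.norm(B, "fro"))
    return np.arccos(np.clip(cos, -1.0, 1.0))  # clip guards against rounding just outside [-1, 1]

d = 64
I = np.eye(d)
print(frobenius_angle(-1500.0 * I, I))         # pi: the product equals -c I
print(frobenius_angle(3.0 * I, I))             # 0 for a positive multiple of I
```

Because the angle ignores the scale $c$, the caption's second claim ($c > 1000$) has to come from the matrix norm rather than from this quantity.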

Theorems & Definitions (67)

  • Theorem 2: Low-Rank Approximation Lower Bounds, Equivariant Case
  • Theorem 4: Low-Rank Approximation Lower Bounds, Biased Case
  • Remark 5: Bound on the weights
  • Definition 6
  • Theorem 7
  • Conjecture 8
  • Definition 9
  • Definition 10
  • Lemma 11
  • Proof
  • ...and 57 more