Tucker Attention: A generalization of approximate attention mechanisms

Timon Klein, Jonas Kusch, Sebastian Sager, Stefan Schnake, Steffen Schotthöfer

Abstract

The pursuit of reducing the memory footprint of the self-attention mechanism in multi-head self-attention (MHA) has spawned a rich portfolio of methods, e.g., grouped-query attention (GQA) and multi-head latent attention (MLA). These methods leverage specialized low-rank factorizations across embedding dimensions or attention heads. From the point of view of classical low-rank approximation, these methods are unconventional and raise the questions of which objects they actually approximate and how to interpret the low-rank behavior of the resulting representations. To answer these questions, this work proposes a generalized view of the weight objects in the self-attention layer and a factorization strategy, which allows us to construct a parameter-efficient scheme called Tucker Attention. Tucker Attention requires an order of magnitude fewer parameters for comparable validation metrics compared to GQA and MLA, as evaluated in LLM and ViT test cases. Additionally, Tucker Attention encompasses GQA, MLA, and MHA as special cases and is fully compatible with flash-attention and rotary position embeddings (RoPE). This generalization strategy yields insights into the actual ranks achieved by MHA, GQA, and MLA, and further enables simplifications for MLA.
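As a rough illustration of the factorization viewpoint (the paper's exact parametrization is given in Definition 2.1; the names, ranks, and scaling below are illustrative assumptions, not the paper's), the following sketch forms pre-softmax attention scores from a Tucker-style factorization of the head-indexed weight tensor, contracting the small factors first so that the full $({n_{\rm{H}}}, d_{\rm{model}}, d_{\rm{model}})$ tensor is never materialized.

```python
# Illustrative sketch only: pre-softmax scores from a Tucker-factorized
# weight tensor W ~ C x1 U1 x2 U2 x3 U3. Shapes, ranks, and scaling are
# assumptions for demonstration, not the paper's exact parametrization.
import numpy as np

n_H, d_model, seq = 8, 256, 128      # heads, embedding dim, sequence length
r1, r2, r3 = 4, 32, 32               # assumed Tucker ranks [r1, r2, r3]

rng = np.random.default_rng(0)
C  = rng.standard_normal((r1, r2, r3))       # core tensor
U1 = rng.standard_normal((n_H, r1))          # head-mode factor
U2 = rng.standard_normal((d_model, r2))      # query-side factor
U3 = rng.standard_normal((d_model, r3))      # key-side factor
X  = rng.standard_normal((seq, d_model))     # token embeddings

# Contract the small factors with the embeddings first; the full
# (n_H, d_model, d_model) pre-softmax tensor is never built explicitly.
Q_lat = X @ U2                               # (seq, r2) latent queries
K_lat = X @ U3                               # (seq, r3) latent keys
scores = np.einsum('ha,abc,tb,sc->hts', U1, C, Q_lat, K_lat)
scores /= np.sqrt(d_model)                   # illustrative scaling only
print(scores.shape)                          # (n_H, seq, seq)
```

Contracting against the token embeddings before forming any $d_{\rm{model}} \times d_{\rm{model}}$ object is what keeps parameter and activation cost at the level of the Tucker ranks rather than the full embedding dimension.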

Paper Structure

This paper contains 58 sections, 5 theorems, 27 equations, 9 figures, 6 tables.

Key Result

Lemma 2.1

Assume $D_i = A_i B_i^{\top}$, where $A_i, B_i \in \mathbb{R}^{d_{\rm{model}} \times d_{\rm{H}}}$ and $D_i\in\mathbb{R}^{d_{\rm{model}} \times d_{\rm{model}}}$ are matrices. Further, let $\mathcal{A}, \mathcal{B}\in\mathbb{R}^{{n_{\rm{H}}} \times d_{\rm{model}} \times d_{\rm{H}}}$ …
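A minimal numerical sketch of one natural reading of this stacking construction (array names and the verification are illustrative, not the paper's notation or proof): collect the per-head factors $A_i, B_i$ along a head mode and check that contracting along $d_{\rm{H}}$ reproduces each $D_i$ as a slice of the resulting third-order tensor.

```python
# Illustrative check of the stacking idea behind Lemma 2.1 (names assumed).
import numpy as np

n_H, d_model, d_H = 4, 64, 16
rng = np.random.default_rng(1)
A = rng.standard_normal((n_H, d_model, d_H))   # stacked A_i along the head mode
B = rng.standard_normal((n_H, d_model, d_H))   # stacked B_i along the head mode

W = np.einsum('imk,ink->imn', A, B)            # contract along d_H, head by head
for i in range(n_H):
    assert np.allclose(W[i], A[i] @ B[i].T)    # slice i recovers D_i = A_i B_i^T
print(W.shape)                                  # (n_H, d_model, d_model)
```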

Figures (9)

  • Figure 1: Illustration of the pre-softmax attention tensor ${\mathcal{W}}$ under existing factorizations; colored edges denote contraction dimensions. MHA: ${\mathcal{W}}$ is formed by contracting $d_{\rm{model}} \times d_{\rm{H}}$ query and key matrices along $d_{\rm{H}}$ (orange) for each of the ${n_{\rm{H}}}$ heads, yielding tensor rank $({n_{\rm{H}}}, d_{\rm{model}}, d_{\rm{model}})$. GQA: Queries follow the same parametrization as MHA, but only $n_{\rm{KV}}$ key matrices are used. During contraction along $d_{\rm{H}}$ (orange), each key matrix is broadcasted to all queries in its head group, yielding rank $({n_{\rm{H}}}, d_{\rm{model}}, n_{\rm{KV}}d_{\rm{H}})$. MLA: Key and query matrices are replaced by down- and up-projections, with up-projections reshaped to ${n_{\rm{H}}} \times d_{\rm{c}} \times d_{\rm{H}}$. Contractions occur along $d_{\rm{c}}^Q, d_{\rm{c}}^K$ (red/blue), and along $d_{\rm{H}}$ per head, giving rank $({n_{\rm{H}}}, d_{\rm{c}}^{\rm{Q}}, d_{\rm{c}}^{\rm{K}})$. At inference, contractions along $d_{\rm{c}}^{\rm{Q}}$ and $d_{\rm{H}}$ are precomputed. The post-softmax tensor $\widetilde{\mathcal{W}}$ is parametrized analogously for MHA and GQA; for MLA, only the value matrix has low-rank factorization while the output matrix remains full-rank.
  • Figure 2: Normalized singular spectrum of the pre-softmax head mode, $\mathrm{Mat}_1(\mathcal{W})$, for GPT2 pretraining (a minimal sketch of this matricization appears after this list). See \ref{fig:gpt_tucker_ranks} for a detailed description. MLA exhibits low-rank behavior across heads similar to Tucker Attention, but is not able to leverage this fact to reduce parameters.
  • Figure 3: ViT32l (left) and ViT14g (right) top-5 validation performance on ImageNet1k for MHA and approximate attention mechanisms over total attention parameter count (top left is best). GQA is presented with $n_{\rm{KV}}=1$ to $4$ KV heads, where GQA-1 corresponds to MQA. MLA is presented with latent dimensions 16 to 64 and individual KV weights. Tucker Attention is presented with ranks $[r_{\rm{1}},r_{\rm{2}},r_{\rm{3}}]$. It is apparent that Tucker Attention requires an order of magnitude fewer parameters for comparable accuracy levels and thus shifts the parameter-accuracy Pareto frontier to the top left.
  • Figure 4: Llama3-1B training cross-entropy loss (left) and final-time validation cross-entropy loss (right) on OpenWebtext2 with RoPE in bf16 floating-point precision. The approximate attention methods are MLA, GQA, and Tucker Attention. Tucker Attention converges well, despite having only a fraction of the trainable parameters of the other methods, and does not exhibit instabilities.
  • Figure 5: Normalized singular spectrum of the transformer layers in GPT2 after training. The singular spectrum was computed via a matricization of the tensor along a single mode; specific information about which mode is given in each subplot. The ribbon plots are calculated as a sample distribution over all twelve transformer layers in GPT2.
  • ...and 4 more figures
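A minimal sketch, under assumed GPT2-like shapes (12 heads, $d_{\rm{model}} = 768$), of the mode-1 (head-mode) matricization and normalized singular spectrum referenced in Figures 2 and 5; the unfolding convention and normalization here are assumptions for illustration, not taken from the paper.

```python
# Illustrative only: mode-1 unfolding of a head-indexed attention tensor and
# its normalized singular spectrum. W is random here, so no low-rank structure
# is expected; for a trained model W would be built from actual Q/K weights.
import numpy as np

n_H, d_model = 12, 768                        # GPT2-like head count and width
W = np.random.default_rng(2).standard_normal((n_H, d_model, d_model))

Mat1 = W.reshape(n_H, d_model * d_model)      # mode-1 (head-mode) unfolding
s = np.linalg.svd(Mat1, compute_uv=False)     # n_H singular values, descending
print(s / s[0])                               # normalized singular spectrum
```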

Theorems & Definitions (12)

  • Definition 2.1: Tucker Attention
  • Definition 3.1: Latent RoPE for Tucker Attention
  • Lemma 2.1
  • proof
  • Theorem 2.2
  • proof
  • Theorem 2.3
  • proof
  • Theorem 2.4
  • proof
  • ...and 2 more