QKV Projections Require a Fraction of Their Memory

Malik Khalaf, Yara Shamshoum, Nitzan Hodos, Yuval Sieradzki, Assaf Schuster

TL;DR

Point-Approximate Matrix Multiplication (PAMM) is a novel tensor compression technique that compresses the activations of the $Q,K,V$ projections in attention layers by a factor of up to $\times 512$, effectively erasing their memory footprint while achieving similar or better final perplexity.

Abstract

The Multi-Head Attention mechanism is central to LLM operation, and multiple works target its compute and memory efficiency during training. While most works focus on approximating the scaled dot product, the memory consumption of the linear projections that compute the $Q$, $K$, and $V$ tensors from the input $x$ is often overlooked. To address this, we propose Point-Approximate Matrix Multiplication (PAMM), a novel tensor compression technique that compresses the activations of the $Q,K,V$ projections in attention layers by a factor of up to $\times 512$, effectively erasing their memory footprint, while achieving similar or better final perplexity. PAMM is fully composable with efficient attention techniques such as FlashAttention, making it a practical and complementary method for memory-efficient LLM training.

Paper Structure

This paper contains 38 sections, 2 theorems, 15 equations, 8 figures, 10 tables, 3 algorithms.

Key Result

Lemma 1

The generator representing $A_i$ is the one with the highest absolute cosine similarity to $A_i$, i.e. $f(i)=\mathop{\mathrm{arg\,max}}\limits_j{\vert\textup{csim}(A_i,C_j)\vert}$.
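A short derivation shows why Lemma 1 holds. It is a sketch, not the paper's proof, and it assumes that the "best" generator is the one whose line passes closest to $A_i$ in Euclidean distance (as Figure 1's description of $\tilde{A}_i$ suggests) and that every $C_j$ is nonzero. The symbol $\alpha^\ast_{ij}$ is introduced here for illustration:

```latex
\alpha^\ast_{ij} = \mathop{\mathrm{arg\,min}}\limits_{\alpha\in\mathbb{R}} \lVert A_i - \alpha C_j \rVert^2
                 = \frac{\langle A_i, C_j\rangle}{\lVert C_j\rVert^2},
\qquad
\min_{\alpha\in\mathbb{R}} \lVert A_i - \alpha C_j \rVert^2
  = \lVert A_i\rVert^2 - \frac{\langle A_i, C_j\rangle^2}{\lVert C_j\rVert^2}
  = \lVert A_i\rVert^2\bigl(1-\textup{csim}(A_i,C_j)^2\bigr).
```

Because $\lVert A_i\rVert^2$ does not depend on $j$, minimizing the residual over $j$ is equivalent to maximizing $\vert\textup{csim}(A_i,C_j)\vert$, which yields $f(i)$; the corresponding scaling factor is then $\alpha_i=\alpha^\ast_{i f(i)}$.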

Figures (8)

  • Figure 1: Illustration of Point-Approximate Matrix Multiplication (PAMM). PAMM approximates the matrix multiplication $O=A^\top B$ for $A\in\mathbb{R}^{b\times n}$ and $B\in\mathbb{R}^{b\times m}$ in two stages. First, each row $A_i$ is represented by $\tilde{A}_i$, defined as the closest point to $A_i$ on the line spanned by the best representative $C_j$. Instead of storing the full matrix $A$, PAMM keeps only a small set of $k$ generators $C\in \mathbb{R}^{k\times n}$, together with an assignment mapping $f\in\mathbb{R}^b$ and scaling factors $\alpha\in\mathbb{R}^b$. Using these, PAMM approximates the product by first contracting $B$ into $\tilde{B}\in\mathbb{R}^{k\times m}$ and then computing $\tilde{O}=C^\top\tilde{B}$ (equivalently, $\tilde{O} = \tilde{A}^\top B$). In our setting, PAMM approximates the gradients of the projection weights $W_Q,W_K,W_V$ during backpropagation, e.g. $\widetilde{\nabla W}_Q=\tilde{X}^\top\nabla Q$. Since the number of generating points $\{C_j\}_{j=1}^k$ is typically very small, the memory required to store $X$ is drastically reduced.
  • Figure 2: Overview of training with PAMM. In a standard linear layer, backpropagation multiplies the upstream gradient by the input $X$ to compute the gradient with respect to the parameters, $\nabla W=X^\top\cdot\nabla Z$. In our method, for the $W_Q,W_K,W_V$ projections, we store a compressed version of $X$ and approximate the gradient-input matrix multiplication with PAMM. Here, $b$ denotes the total number of tokens in a training batch. Pseudocode is provided in the appendix.
  • Figure 3: Model perplexity and attention memory when pretraining LLaMA models on C4 with PAMM. PAMM achieves massive memory reductions while maintaining or even improving perplexity compared to the baseline. Results are averaged over 3 runs per configuration.
  • Figure 4: (a) Comparison of compression techniques applied to the $Q,K,V$ projections of LLaMA-60M at different compression rates. PAMM significantly outperforms other compression methods and can use a very low $r$. (b) Effect of $\varepsilon$ on the performance of LLaMA-60M with PAMM. When $\varepsilon=0$, PAMM is equivalent to Uniform-CRS; when $\varepsilon=\infty$, PAMM does not enforce the neighborhood condition. Choosing $\varepsilon=\infty$ performs best at every compression rate tested.
  • Figure 5: Two-dimensional PCA visualization of the PAMM approximation. (a) Rows of the input tensor of the $K$ projection of layer 3 in LLaMA-60M, $X_{K_3}$. The rows are projected using PCA and colored by their assigned generating point, i.e. by $f(i)$, using PAMM with $\varepsilon=\infty$. (b) The representative rows $\tilde{X}_{K_3}$ shown in the same PCA space, colored by $f(i)$. PAMM approximately clusters the input data and collapses each cluster onto a line.
  • ...and 3 more figures
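The two stages described in Figure 1, applied to the weight-gradient product from Figure 2, can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. The function names `pamm_compress` and `pamm_matmul` are hypothetical. The generators are assumed to be $k$ rows of $A$ drawn uniformly without replacement. The $\varepsilon$ neighborhood condition is omitted, which corresponds to the $\varepsilon=\infty$ setting from Figure 4.

```python
import numpy as np

def pamm_compress(A, k, rng):
    """Compress A (b x n) into generators C (k x n), assignment f (b,), scales alpha (b,).

    ASSUMPTIONS (illustrative sketch): generators are k rows of A sampled uniformly
    without replacement, and no epsilon neighborhood test is applied (epsilon = inf).
    """
    b = A.shape[0]
    C = A[rng.choice(b, size=k, replace=False)]               # (k, n) generators
    tiny = 1e-12                                              # avoids division by zero
    dots = A @ C.T                                            # (b, k) inner products
    norms = np.linalg.norm(A, axis=1, keepdims=True) * np.linalg.norm(C, axis=1)[None, :]
    f = np.argmax(np.abs(dots) / (norms + tiny), axis=1)      # Lemma 1: max |csim|
    # Closest point on span{C_f(i)}: alpha_i = <A_i, C_f(i)> / ||C_f(i)||^2
    alpha = dots[np.arange(b), f] / (np.sum(C[f] ** 2, axis=1) + tiny)
    return C, f, alpha

def pamm_matmul(C, f, alpha, B):
    """Approximate A^T B by C^T B_tilde, with B_tilde_j = sum over {i : f(i) = j} of alpha_i * B_i."""
    B_tilde = np.zeros((C.shape[0], B.shape[1]), dtype=B.dtype)
    np.add.at(B_tilde, f, alpha[:, None] * B)                 # scatter-add rows of B into k rows
    return C.T @ B_tilde                                      # (n, m), equals A_tilde^T B

# Usage: approximate the weight gradient of a Q projection, grad_W_Q ~= X_tilde^T dQ.
rng = np.random.default_rng(0)
X = rng.standard_normal((256, 32))                            # saved input, b = 256 tokens
dQ = rng.standard_normal((256, 16))                           # upstream gradient
C, f, alpha = pamm_compress(X, k=8, rng=rng)
grad_WQ_approx = pamm_matmul(C, f, alpha, dQ)                 # shape (32, 16)
```

`np.add.at` is used instead of fancy-index assignment because several rows usually share a generator, and their contributions $\alpha_i B_i$ must be accumulated rather than overwritten. Under this sketch, only `C`, `f` and `alpha` need to be kept between the forward and backward passes: $kn + 2b$ values instead of $bn$. With illustrative sizes $b=4096$, $n=4096$, $k=8$ (not taken from the paper), that is 40,960 versus 16,777,216 values.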

Theorems & Definitions (4)

  • Lemma 1
  • Lemma 2: $k$ bound under uniform sampling
  • proof
  • proof