
KV-weights are all you need for skipless transformers

Nils Graef

TL;DR

The paper addresses a limitation of skipless transformers: the known weight-reducing scheme applies to MHA, but not to the MQA and GQA variants used by many LLMs. It proposes mathematically equivalent variants that merge the post-attention projection (P) together with one of the query, key, or value projections (Q, K, or V) into the surrounding layers. Assuming the relevant matrices are invertible, this removes about 15–16% of the weights per block and yields corresponding speedups. Key contributions include explicit formulas for the new weight matrices, parallel and non-parallel skipless variants for MHA, MQA, and GQA, and concrete model configurations, with code that checks numerical equivalence and matrix invertibility. The approach offers a practical path toward more memory- and compute-efficient LLM inference, in line with a broader trend toward leaner architectures; retrofitting it to models with normalization and skip connections is left for future work.
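The "explicit formulas for the new weight matrices" can be derived in a few lines. The sketch below assumes row vectors (a layer computes x W), a skipless block without normalization, and an FFN whose first and last layers are M_i and O_i; u_{i-1} denotes the activated FFN hidden state of block i-1 and a_i the attention output of block i (both labels are introduced here for illustration). The starred symbols mirror the Figure 1 caption, but this is a derivation of variant (b), not a copy of the paper's Table 1.

```latex
\begin{aligned}
\mathbf{x}_i &= \mathbf{u}_{i-1}\mathbf{O}_{i-1}
  && \text{input of block } i\\
\mathbf{O}_{i-1}^* &= \mathbf{O}_{i-1}\mathbf{Q}_i
  \;\Rightarrow\; \mathbf{x}_i^* = \mathbf{u}_{i-1}\mathbf{O}_{i-1}^* = \mathbf{x}_i\mathbf{Q}_i
  && \text{queries are the new input itself}\\
\mathbf{x}_i\mathbf{K}_i &= \mathbf{x}_i^*\,\mathbf{Q}_i^{-1}\mathbf{K}_i = \mathbf{x}_i^*\mathbf{K}_i^*
  && \mathbf{K}_i^* = \mathbf{Q}_i^{-1}\mathbf{K}_i\\
\mathbf{x}_i\mathbf{V}_i &= \mathbf{x}_i^*\,\mathbf{Q}_i^{-1}\mathbf{V}_i = \mathbf{x}_i^*\mathbf{V}_i^*
  && \mathbf{V}_i^* = \mathbf{Q}_i^{-1}\mathbf{V}_i\\
(\mathbf{a}_i\mathbf{P}_i)\,\mathbf{M}_i &= \mathbf{a}_i\,(\mathbf{P}_i\mathbf{M}_i) = \mathbf{a}_i\mathbf{M}_i^*
  && \mathbf{M}_i^* = \mathbf{P}_i\mathbf{M}_i
\end{aligned}
```

Only Q_i must be invertible here. For the first block, the embedding table plays the role of O_0, and if the FFN has several input matrices (e.g. gate and up projections), each of them is premultiplied by P_i.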

Abstract

He and Hofmann (arXiv:2311.01906) detailed a skipless transformer without the V and P (post-attention projection) linear layers, which reduces the total number of weights. However, this scheme is only applicable to MHA (multi-head attention), but not to MQA (multi-query attention) or GQA (grouped-query attention). The latter schemes are used by many popular LLMs such as Llama 2, Mistral, Mixtral, PaLM, and Gemma. Therefore, this micro-paper proposes mathematically equivalent versions that are suitable for MQA and GQA. For example, removing Q and P from a skipless version of Mistral-7B would remove 15% of its weights (and thus reduce its compute and memory complexity). Watch our explainer video https://youtu.be/Tx_lMpphd2g and see https://github.com/OpenMachine-ai/transformer-tricks for code and more transformer tricks.
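The 15% figure for Mistral-7B can be reproduced with a back-of-the-envelope count. The sketch below assumes the shapes from Mistral-7B's public configuration (hidden size 4096, 32 query heads, 8 KV heads, SwiGLU FFN width 14336, 32 layers, vocabulary 32000, untied embeddings); these numbers are not stated in this summary, and the small normalization vectors are ignored.

```python
# Assumed Mistral-7B shapes (from its public config, not from this paper summary).
d_model = 4096                    # hidden size
n_heads = 32                      # query heads
n_kv_heads = 8                    # KV heads (GQA)
head_dim = d_model // n_heads     # 128
d_ffn = 14336                     # SwiGLU intermediate size
n_layers = 32
vocab = 32000

kv_dim = n_kv_heads * head_dim    # 1024: K and V are d_model x 1024

q = d_model * d_model             # Q is square under GQA
k = v = d_model * kv_dim          # K and V are not square
p = d_model * d_model             # post-attention projection P
ffn = 3 * d_model * d_ffn         # gate, up and down matrices

per_block = q + k + v + p + ffn
removed_per_block = q + p         # the layers eliminated by the Q/P merge

embeddings = 2 * vocab * d_model  # untied input and output embeddings
total = n_layers * per_block + embeddings
removed = n_layers * removed_per_block

print(f"per block:   {removed_per_block / per_block:.1%}")   # 15.4%
print(f"whole model: {removed / total:.1%}")                 # 14.8%
```

Per block, Q and P account for exactly 2/13 of the weights; the embeddings dilute this slightly for the whole model, which rounds to the 15% quoted above.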

Paper Structure

This paper contains 6 sections, 4 figures, and 1 table.

Figures (4)

  • Figure 1: (a) Skipless vanilla transformer; equivalent versions with (b) Q and P merged into the FFN (feedforward network); (c) K and P merged into the FFN; (d) V and P merged into the FFN. $\mathbf{M}_i^*, \mathbf{Q}_i^*, \mathbf{K}_i^*, \mathbf{V}_i^*, \mathbf{O}_{i-1}^*$ are defined in Table 1.
  • Figure 2: (a) Merging P and M; (b) eliminating Q; (c) eliminating K; (d) eliminating V.
  • Figure 3: Parallel skipless transformers (a) without Q and P; (b) without K and P; (c) without V and P.
  • Figure 4: (a) Transformer block without Q and P; (b) version with parallel attention / FFN.
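The equivalence behind Figure 1(b) and Figure 4(a) (a block without Q and P) can be checked numerically. The following is a minimal NumPy sketch, not the code from the transformer-tricks repository: it assumes a toy two-block skipless GQA stack with random weights, a ReLU FFN, no normalization, and hypothetical sizes and helper names (`gqa_attention`, `block`, `forward`, `rand`). It folds each Q_i into the layer feeding block i (the embedding for block 1, otherwise the previous O), replaces K_i and V_i by Q_i^{-1}K_i and Q_i^{-1}V_i, and folds P_i into M_i.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes for illustration only (not from the paper).
d, n_heads, n_kv, d_ffn, seq, vocab = 64, 8, 2, 128, 10, 50
hd = d // n_heads        # head dimension
kv_dim = n_kv * hd       # K and V are d x kv_dim, i.e. NOT square under GQA


def gqa_attention(q, k, v):
    """Causal grouped-query attention. q: (seq, d); k, v: (seq, kv_dim)."""
    s = q.shape[0]
    q = q.reshape(s, n_heads, hd)
    k = k.reshape(s, n_kv, hd)
    v = v.reshape(s, n_kv, hd)
    group = n_heads // n_kv
    future = np.triu(np.ones((s, s), dtype=bool), k=1)
    out = np.empty_like(q)
    for h in range(n_heads):
        g = h // group                              # KV head shared by this query head
        scores = q[:, h] @ k[:, g].T / np.sqrt(hd)
        scores[future] = -np.inf                    # causal mask
        w = np.exp(scores - scores.max(axis=1, keepdims=True))
        w /= w.sum(axis=1, keepdims=True)           # softmax over keys
        out[:, h] = w @ v[:, g]
    return out.reshape(s, d)


def block(x, W):
    """Skipless block (no residuals, no norms); Q or P set to None means removed."""
    q = x if W["Q"] is None else x @ W["Q"]
    a = gqa_attention(q, x @ W["K"], x @ W["V"])
    h = a if W["P"] is None else a @ W["P"]
    return np.maximum(h @ W["M"], 0.0) @ W["O"]    # ReLU FFN (assumed)


def forward(tokens, E, blocks):
    x = E[tokens]
    for W in blocks:
        x = block(x, W)
    return x


def rand(rows, cols):
    return rng.standard_normal((rows, cols)) / np.sqrt(rows)


# Vanilla skipless weights.
E = rand(vocab, d)
blocks = [
    dict(Q=rand(d, d), K=rand(d, kv_dim), V=rand(d, kv_dim),
         P=rand(d, d), M=rand(d, d_ffn), O=rand(d_ffn, d))
    for _ in range(2)
]

# Merged weights: only Q (square, invertible) is used as the pivot.
E_star = E @ blocks[0]["Q"]                      # embedding feeds block 1
merged = []
for i, W in enumerate(blocks):
    next_Q = blocks[i + 1]["Q"] if i + 1 < len(blocks) else None
    merged.append(dict(
        Q=None,
        K=np.linalg.solve(W["Q"], W["K"]),       # Q^{-1} K
        V=np.linalg.solve(W["Q"], W["V"]),       # Q^{-1} V
        P=None,
        M=W["P"] @ W["M"],                       # P M
        O=W["O"] if next_Q is None else W["O"] @ next_Q,  # O Q_next
    ))

tokens = rng.integers(0, vocab, size=seq)
y_ref = forward(tokens, E, blocks)
y_new = forward(tokens, E_star, merged)
print(np.allclose(y_ref, y_new))                 # expected: True
```

`np.linalg.solve` computes Q^{-1}K without forming the inverse explicitly. Because K and V are d x kv_dim under GQA, they cannot be inverted, which is why this sketch uses Q rather than V as the pivot; with MHA (n_kv equal to n_heads), K or V could play the same role.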