KV-weights are all you need for skipless transformers
Nils Graef
TL;DR
The paper addresses a limitation of skipless transformers. The earlier scheme for removing their weights applies only to MHA, not to the MQA or GQA used by many LLMs. The paper proposes mathematically equivalent variants that fold the post-attention projection and the query, key, or value projection into the surrounding layers, assuming certain weight matrices are invertible. This removes about 15–16% of the weights per transformer block and yields corresponding speedups. Key contributions include explicit formulas for the new weight matrices, parallel and non-parallel skipless variants for MHA, MQA, and GQA, and concrete model configurations. These are backed by code that demonstrates numerical equivalence and checks matrix invertibility. The approach offers a practical path toward more memory- and compute-efficient inference for large LLMs. It also invites future work on retrofitting the technique to models with normalization and skip connections, in line with a broader trend toward architectural parsimony.
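The folding described above can be checked numerically. Below is a minimal NumPy sketch, not the paper's code, that removes Q and P from a toy skipless GQA model. It makes three substitutions:

- P is folded into the first FFN matrix (M* = P M).
- Each block's Q is folded into whatever produces that block's input: the previous block's FFN output matrix O, or the embedding table for the first block.
- Q^-1 is folded into K and V (K* = Q^-1 K, V* = Q^-1 V).

All of the following are our assumptions for illustration: no normalization or positional encoding, a plain ReLU FFN with matrices M and O, random (hence almost surely invertible) Q, toy sizes, and hypothetical helper names such as `remove_q_and_p`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes chosen for illustration only: 4 query heads share 2 KV heads (GQA).
vocab, d, n_heads, n_kv, d_ff, seq = 10, 16, 4, 2, 32, 5
d_head = d // n_heads


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def gqa(q, k, v):
    """Causal grouped-query attention. q: (n, d); k, v: (n, n_kv * d_head)."""
    n = q.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)
    group = n_heads // n_kv
    heads = []
    for h in range(n_heads):
        g = h // group  # index of the KV head shared by query head h
        qh = q[:, h * d_head:(h + 1) * d_head]
        kh = k[:, g * d_head:(g + 1) * d_head]
        vh = v[:, g * d_head:(g + 1) * d_head]
        scores = qh @ kh.T / np.sqrt(d_head)
        scores[mask] = -np.inf  # causal mask; the diagonal always stays finite
        heads.append(softmax(scores) @ vh)
    return np.concatenate(heads, axis=1)


def block(x, W):
    """Skipless block (no skips, no norms): GQA, then a ReLU FFN.
    A missing "Q" or "P" entry is treated as the identity matrix."""
    q = x @ W["Q"] if "Q" in W else x
    a = gqa(q, x @ W["K"], x @ W["V"])
    if "P" in W:
        a = a @ W["P"]
    return np.maximum(a @ W["M"], 0.0) @ W["O"]


def random_block():
    s = d ** -0.5
    return {"Q": rng.normal(size=(d, d)) * s,
            "K": rng.normal(size=(d, n_kv * d_head)) * s,
            "V": rng.normal(size=(d, n_kv * d_head)) * s,
            "P": rng.normal(size=(d, d)) * s,
            "M": rng.normal(size=(d, d_ff)) * s,
            "O": rng.normal(size=(d_ff, d)) * d_ff ** -0.5}


def remove_q_and_p(E, blocks):
    """Return an equivalent embedding table and blocks without Q and P.
    Assumes every Q is invertible."""
    E_new = E @ blocks[0]["Q"]  # the first block's Q folds into the embedding table
    new_blocks = []
    for i, W in enumerate(blocks):
        Q_inv = np.linalg.inv(W["Q"])
        # O* = O Q_next; the last block's O stays unchanged
        O = W["O"] @ blocks[i + 1]["Q"] if i + 1 < len(blocks) else W["O"]
        new_blocks.append({"K": Q_inv @ W["K"],   # K* = Q^-1 K
                           "V": Q_inv @ W["V"],   # V* = Q^-1 V
                           "M": W["P"] @ W["M"],  # M* = P M
                           "O": O})
    return E_new, new_blocks


def forward(E, blocks, tokens):
    x = E[tokens]
    for W in blocks:
        x = block(x, W)
    return x


E = rng.normal(size=(vocab, d))
blocks = [random_block() for _ in range(3)]
tokens = rng.integers(0, vocab, size=seq)

E_new, new_blocks = remove_q_and_p(E, blocks)
y_orig = forward(E, blocks, tokens)
y_new = forward(E_new, new_blocks, tokens)

# Block weight counts (the embedding table keeps its size).
n_orig = sum(w.size for W in blocks for w in W.values())
n_new = sum(w.size for W in new_blocks for w in W.values())
print(np.allclose(y_orig, y_new), n_orig, n_new)
```

The sketch relies on Q being square (d×d). In GQA, K and V have only n_kv·d_head columns, so they are not square and cannot be inverted the same way. The toy sizes exaggerate the savings, which here are 2·d² weights per block.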
Abstract
He and Hofmann (arXiv:2311.01906) detailed a skipless transformer without the V and P (post-attention projection) linear layers, which reduces the total number of weights. However, this scheme applies only to MHA (multi-head attention), not to MQA (multi-query attention) or GQA (grouped-query attention). The latter schemes are used by many popular LLMs such as Llama 2, Mistral, Mixtral, PaLM, and Gemma. Therefore, this micro-paper proposes mathematically equivalent versions that are suitable for MQA and GQA. For example, removing Q and P from a skipless version of Mistral-7B would remove 15% of its weights (and thus reduce its compute and memory complexity). Watch our explainer video https://youtu.be/Tx_lMpphd2g and see https://github.com/OpenMachine-ai/transformer-tricks for code and more transformer tricks.
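The 15% figure for Mistral-7B can be reproduced with simple arithmetic. The sketch below assumes the published Mistral-7B-v0.1 configuration:

- d = 4096
- 32 layers
- 32 query heads and 8 KV heads
- SwiGLU FFN width 14336
- vocabulary 32000
- untied input and output embeddings

It also ignores the small normalization weights. These sizes are our assumption and are not stated in this section.

```python
# Assumed Mistral-7B-v0.1 sizes; norm weights are ignored.
d, d_ff, n_layers = 4096, 14336, 32
n_heads, n_kv_heads, vocab = 32, 8, 32000
d_head = d // n_heads

q = p = d * d                    # square query and post-attention projections
k = v = d * n_kv_heads * d_head  # GQA: 8 KV heads, so 4096 x 1024 each
ffn = 3 * d * d_ff               # SwiGLU: gate, up, and down matrices
per_block = q + k + v + p + ffn
embeddings = 2 * vocab * d       # untied input embedding and output head

total = n_layers * per_block + embeddings
removed = n_layers * (q + p)     # Q and P removed from every block

print(f"per block:   {(q + p) / per_block:.1%}")
print(f"whole model: {removed / total:.1%}")
```

This gives about 15.4% (exactly 2/13) of each block and about 14.8% of the roughly 7.24 billion weights in the whole model. Both figures are consistent with the rounded 15% above.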
