Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs
Tao Ji, Bin Guo, Yuanbin Wu, Qipeng Guo, Lixing Shen, Zhan Chen, Xipeng Qiu, Qi Zhang, Tao Gui
TL;DR
This work tackles the costly KV-cache memory bottleneck in autoregressive LLMs by proposing MHA2MLA, a data-efficient fine-tuning framework that migrates pre-trained MHA models to the MLA architecture with minimal data (about 0.6%–1%). It combines two core ideas: partial-RoPE, which removes RoPE from selected query and key dimensions, and low-rank SVD-based projections, which compress the remaining NoPE components into a latent KV cache that preserves most pre-trained knowledge. Empirical results across model scales (135M–13B) and tasks show near-baseline performance with substantial KV-cache reductions (up to around 97% in some settings), and the method combines well with KV-cache quantization, including 4-bit (Int4) schemes. The findings demonstrate a practical pathway for deploying resource-efficient LLMs without retraining from scratch while maintaining commonsense reasoning and long-context capabilities; future work includes broader verification on larger LLMs and further refinement with parameter-efficient fine-tuning.
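The two components can be illustrated with a small NumPy sketch on a single attention head with random weights. It is not the authors' implementation, and several details are assumptions: RoPE frequency pairs are taken to be interleaved as dimensions (2i, 2i+1), whereas Hugging Face's Llama code pairs dimension i with i + head_dim/2; pairs are ranked by the product of their mean query and key norms, one plausible reading of "dimensions that contribute less to the attention scores"; and the joint SVD factors the NoPE key columns together with the value projection, splitting the square root of the singular values between a shared down-projection and two up-projections. The function names `select_rope_pairs`, `split_key_columns`, and `joint_svd` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def select_rope_pairs(q: np.ndarray, k: np.ndarray, r: int) -> np.ndarray:
    """Return the r frequency pairs that keep RoPE (hypothetical heuristic).

    q, k: (tokens, head_dim) activations of one head, pairs assumed interleaved
    as (2i, 2i+1). RoPE rotates within a pair, so a pair's contribution to q.k
    is bounded by ||q_pair|| * ||k_pair||; pairs with the largest mean
    norm product keep RoPE and all others become NoPE.
    """
    t, d = q.shape
    q_norm = np.linalg.norm(q.reshape(t, d // 2, 2), axis=-1).mean(axis=0)
    k_norm = np.linalg.norm(k.reshape(t, d // 2, 2), axis=-1).mean(axis=0)
    score = q_norm * k_norm                        # one score per pair
    return np.sort(np.argsort(score)[::-1][:r])

def split_key_columns(w_k: np.ndarray, keep_pairs: np.ndarray):
    """Split one head's key projection into RoPE columns and NoPE columns."""
    dims = np.concatenate([2 * keep_pairs, 2 * keep_pairs + 1])
    rope = np.zeros(w_k.shape[1], dtype=bool)
    rope[dims] = True
    return w_k[:, rope], w_k[:, ~rope]

def joint_svd(w_k_nope: np.ndarray, w_v: np.ndarray, d_kv: int):
    """Rank-d_kv factorization [W_k_nope | W_v] ~ W_down @ [W_up_k | W_up_v]."""
    w = np.concatenate([w_k_nope, w_v], axis=1)    # (d_model, n_k + n_v)
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    root = np.sqrt(s[:d_kv])
    w_down = u[:, :d_kv] * root                    # (d_model, d_kv), shared
    w_up = root[:, None] * vt[:d_kv]               # (d_kv, n_k + n_v)
    n_k = w_k_nope.shape[1]
    return w_down, w_up[:, :n_k], w_up[:, n_k:]

# Toy sizes for one head (hypothetical).
d_model, head_dim, tokens, r, d_kv = 64, 16, 32, 2, 8
w_q = rng.normal(size=(d_model, head_dim))
w_k = rng.normal(size=(d_model, head_dim))
w_v = rng.normal(size=(d_model, head_dim))
x = rng.normal(size=(tokens, d_model))

keep = select_rope_pairs(x @ w_q, x @ w_k, r)
w_k_rope, w_k_nope = split_key_columns(w_k, keep)
w_down, w_up_k, w_up_v = joint_svd(w_k_nope, w_v, d_kv)

# Only the RoPE key dims and the latent are cached; NoPE keys and values
# are rebuilt from the latent at attention time.
latent = x @ w_down                                # (tokens, d_kv)
k_nope_approx = latent @ w_up_k
v_approx = latent @ w_up_v
print("cached values per token:", 2 * r + d_kv, "instead of", 2 * head_dim)
```

Because the factors come from the pre-trained weights, the latent starts as the best rank-d_kv approximation of the original projections rather than a random initialization: with d_kv equal to the full rank the factorization is exact, and by the Eckart–Young theorem the Frobenius error of a smaller d_kv equals the norm of the discarded singular values.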
Abstract
Multi-head Latent Attention (MLA) is an innovative architecture proposed by DeepSeek, designed to ensure efficient and economical inference by significantly compressing the Key-Value (KV) cache into a latent vector. Compared to MLA, standard LLMs employing Multi-Head Attention (MHA) and its variants such as Grouped-Query Attention (GQA) exhibit significant cost disadvantages. Enabling well-trained LLMs (e.g., Llama) to rapidly adapt to MLA without pre-training from scratch is both meaningful and challenging. This paper proposes the first data-efficient fine-tuning method for transitioning from MHA to MLA (MHA2MLA), which includes two key components: for partial-RoPE, we remove RoPE from the dimensions of queries and keys that contribute less to the attention scores; for low-rank approximation, we introduce joint SVD approximations based on the pre-trained parameters of keys and values. These carefully designed strategies enable MHA2MLA to recover performance using only a small fraction (0.3% to 0.6%) of the data, significantly reducing inference costs while integrating seamlessly with compression techniques such as KV cache quantization. For example, the KV cache size of Llama2-7B is reduced by 92.19%, with only a 0.5% drop in LongBench performance.
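Reduction figures such as the 92.19% above compose multiplicatively: if the latent cache keeps a fraction f of the original K/V elements and quantization stores them at b bits instead of 16, the cache shrinks to f·b/16 of its original size. The short Python sketch below checks that arithmetic for a Llama2-7B-shaped model (32 layers, 32 heads, head dimension 128, from its public config). The kept fractions 0.3125 and 0.125 are illustrative assumptions chosen to be consistent with the quoted 92.19% and "around 97%" figures under 4-bit storage; they are not values stated in this summary, and the helper names are hypothetical.

```python
# Llama2-7B attention shape from its public config.
N_LAYERS, N_HEADS, HEAD_DIM = 32, 32, 128

def cache_bytes(elems_per_token_per_layer: int, bits: int, seq_len: int,
                n_layers: int = N_LAYERS) -> float:
    """Bytes needed to cache seq_len tokens of one sequence across all layers."""
    return n_layers * elems_per_token_per_layer * seq_len * bits / 8

def reduction(kept_fraction: float, bits: int, base_bits: int = 16) -> float:
    """Fractional KV-cache saving when a share of elements is kept at `bits`."""
    return 1.0 - kept_fraction * bits / base_bits

mha_elems = 2 * N_HEADS * HEAD_DIM          # K and V for every head: 8192
baseline = cache_bytes(mha_elems, 16, 4096)
print(baseline / 2**30, "GiB for a 4096-token sequence in FP16")  # 2.0

# Hypothetical kept fractions; 4-bit storage contributes another factor of 4.
print(reduction(0.3125, 4))   # 0.921875, i.e. about 92.19%
print(reduction(0.125, 4))    # 0.96875, i.e. about 97%
```

Because the two factors multiply, a moderate structural compression combined with Int4 quantization already removes most of the cache, which is why compatibility with quantization matters as much as the latent dimension itself.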
