MatryoshkaKV: Adaptive KV Compression via Trainable Orthogonal Projection

Bokai Lin, Zihao Zeng, Zipeng Xiao, Siqi Kou, Tianqi Hou, Xiaofeng Gao, Hao Zhang, Zhijie Deng

TL;DR

This work tackles KV-cache memory bottlenecks in large language models by introducing MatryoshkaKV, a trainable, orthogonal projection-based method to compress the feature dimension of KV caches. Starting from PCA initialization, it uses a Cayley-parameterized orthogonal transform and a Matryoshka training strategy to enable hierarchical, per-head/layer compression with a distillation objective that preserves model outputs. A greedy search then assigns heterogeneous compression rates across layers and heads, balancing performance and memory savings. Empirically, MatryoshkaKV sustains over 90% of baseline accuracy with around 60% average KV-cache compression, and up to 75% in some scenarios, while remaining compatible with LoRA fine-tuning and other KV-cache techniques, demonstrating practical impact for scalable, efficient inference.
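As a rough illustration of what a Cayley-parameterized orthogonal transform looks like, the sketch below maps an unconstrained square matrix to an orthogonal one and then keeps the leading columns as a reduced projection. The function name `cayley`, the dimensions, and the random parameters are assumptions for illustration only; the exact parameterization and initialization used by MatryoshkaKV may differ.

```python
import numpy as np

def cayley(params: np.ndarray) -> np.ndarray:
    """Map an unconstrained square matrix to an orthogonal matrix (Cayley transform)."""
    d = params.shape[0]
    skew = params - params.T            # skew-symmetric: skew.T == -skew
    eye = np.eye(d)
    # Q = (I + A)^{-1} (I - A); I + A is always invertible for skew-symmetric A.
    return np.linalg.solve(eye + skew, eye - skew)

rng = np.random.default_rng(0)
d = 8                                   # hypothetical head dimension
Q = cayley(rng.normal(size=(d, d)))

# Orthogonality holds for any value of the underlying parameters,
# so gradient updates on `params` never leave the orthogonal family.
orth_error = np.abs(Q.T @ Q - np.eye(d)).max()

# Matryoshka-style truncation (assumed reading of the summary):
# the first r columns form a d -> r projection with orthonormal columns.
r = 4
U_r = Q[:, :r]
```

Because the constraint is built into the parameterization, an ordinary optimizer can be applied to `params` without a separate re-orthogonalization step.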

Abstract

The KV cache has become a de facto technique for the inference of large language models (LLMs), where tensors of shape (layer number, head number, sequence length, feature dimension) are introduced to cache historical information for self-attention. As the size of the model and data grows, the KV cache can quickly become a bottleneck within the system in both storage and memory transfer. To address this, prior studies usually focus on the first three axes of the cache tensors for compression. This paper supplements them by focusing on the feature dimension axis, using low-rank projection matrices to transform the cache features into spaces with reduced dimensions. We begin by investigating the canonical orthogonal projection method for data compression, principal component analysis (PCA). We observe that PCA projection suffers significant performance degradation at low compression rates. To bridge the gap, we propose to directly tune the orthogonal projection matrices with a distillation objective using an elaborate Matryoshka training strategy. After training, we adaptively search for the optimal compression rates for various layers and heads given varying compression budgets. Compared to previous works, our method can easily be applied to pre-trained LLMs and maintains a smooth tradeoff between performance and compression rate. We empirically observe the high data efficiency of our training procedure and find that our method can sustain over 90% performance with an average KV cache compression rate of 60% (and up to 75% in certain extreme scenarios) for popular LLMs like LLaMA2-7B-base and Mistral-7B-v0.3-base.
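The PCA baseline described in the abstract can be sketched as follows: estimate an orthogonal basis from cached features, store only the leading coordinates, and project back when needed. This is a minimal sketch on synthetic data; the helper `pca_projection`, the shapes, and the use of an uncentered second-moment matrix are assumptions for illustration, not details taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d, r = 256, 16, 4              # hypothetical sizes: tokens, head dim, kept dims

# Synthetic "keys" that are close to rank r, plus a little noise (illustrative data).
basis = rng.normal(size=(r, d))
keys = rng.normal(size=(seq_len, r)) @ basis + 0.01 * rng.normal(size=(seq_len, d))

def pca_projection(features: np.ndarray) -> np.ndarray:
    """Return an orthogonal matrix whose columns are principal directions,
    sorted by decreasing variance (uncentered, a simplifying assumption)."""
    second_moment = features.T @ features / features.shape[0]
    eigvals, eigvecs = np.linalg.eigh(second_moment)   # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1]
    return eigvecs[:, order]

U = pca_projection(keys)

compressed = keys @ U[:, :r]            # what would be cached: (seq_len, r) instead of (seq_len, d)
reconstructed = compressed @ U[:, :r].T # approximate recovery in the original space
rel_error = np.linalg.norm(keys - reconstructed) / np.linalg.norm(keys)
```

On data this close to low rank the reconstruction error is small; the abstract's point is that reconstruction quality of the cache alone does not guarantee that model outputs are preserved, which motivates tuning the projections with a distillation objective.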

Paper Structure

This paper contains 23 sections, 5 equations, 9 figures, 9 tables, 1 algorithm.

Figures (9)

  • Figure 1: Visualization of the feasible compression level for the keys and values in our model distilled from the LLaMA2-7B-base model. We individually leverage samples from ARC-challenge (ARC-C), ARC-easy (ARC-E), and Winogrande (WG) to determine the compression level. Lighter colors indicate higher compression levels. As shown, our approach enables the use of different compression strategies for different tasks.
  • Figure 2: Vanilla KV cache vs. the proposed MatryoshkaKV. In particular, we introduce orthogonal projection matrices to reduce the dimension of stored keys and values. We explicitly enforce a hierarchy over the columns of projection matrices so as to concentrate the principal information on the head dimensions and enable the adjustment of compression level according to resource constraints.
  • Figure 3: Evaluation loss of four budgets vs. the number of training samples during 1 epoch of SFT on GSM8K (Left). Evaluation loss of models with and without PCA initialization, using a 50% cache budget, vs. the number of training samples during 4 epochs of SFT on GSM8K (Right).
  • Figure 4: Comparison between PCA and distilled MatryoshkaKV projections after CPT, with and without greedy search for adaptive compression levels. We report average accuracy on the datasets listed in the experimental setup of the CPT experiments (Left). Comparison of models trained with and without the Matryoshka training strategy and the orthogonal constraint after SFT on GSM8K. We report accuracy relative to the LLaMA2-7B-base model fine-tuned with LoRA on GSM8K using the full KV cache (Right).
  • Figure 5: After obtaining an orthogonal matrix through training, we merge the parameters in this way, reducing the number of matrix multiplications required during inference without incurring any inference time overhead. Truncation can be achieved simply by removing the columns corresponding to $W_{OV}$, thereby reducing peak memory consumption.
  • ...and 4 more figures
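
The parameter merging mentioned in the Figure 5 caption can be illustrated with a simplified single-head example for the query/key path. Folding an orthogonal matrix into both projection weights leaves the attention scores unchanged at full rank, and truncation then amounts to dropping columns of the merged weights. All names (`W_q`, `W_k`, `U`) and sizes are hypothetical, `U` is a random stand-in for a trained matrix, and the sketch ignores positional encodings such as RoPE.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, r = 32, 8, 4           # hypothetical sizes

W_q = rng.normal(size=(d_model, d_head))
W_k = rng.normal(size=(d_model, d_head))
# Stand-in for a trained orthogonal matrix (random here, via QR).
U, _ = np.linalg.qr(rng.normal(size=(d_head, d_head)))

x_q = rng.normal(size=(1, d_model))     # current token
x_k = rng.normal(size=(5, d_model))     # five past tokens

# Merge once, offline: no extra matrix multiplication at inference time.
W_q_merged = W_q @ U
W_k_merged = W_k @ U

scores_ref = (x_q @ W_q) @ (x_k @ W_k).T
scores_full = (x_q @ W_q_merged) @ (x_k @ W_k_merged).T   # equal, since U @ U.T == I

# Truncation: keep only the first r columns, so the cached keys shrink to (5, r).
k_cache = x_k @ W_k_merged[:, :r]
scores_trunc = (x_q @ W_q_merged[:, :r]) @ k_cache.T      # approximation of scores_ref
```

How close `scores_trunc` stays to `scores_ref` depends on how much information the leading columns of `U` capture, which is what the training strategy described above is meant to improve.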