CARE: Covariance-Aware and Rank-Enhanced Decomposition for Enabling Multi-Head Latent Attention

Zhongzhu Zhou, Fengxiang Bie, Ziyan Chen, Zhenyu Zhang, Yibo Yang, Junxiong Wang, Ben Athiwaratkun, Xiaoxia Wu, Shuaiwen Leon Song

Abstract

Converting pretrained attention modules such as grouped-query attention (GQA) into multi-head latent attention (MLA) can improve expressivity without increasing KV-cache cost, making it attractive for efficient inference. However, many practical conversion baselines rely on weight-only low-rank approximations (e.g., SVD-style initializations) and uniform rank allocation. They focus on minimizing the difference between weight matrices rather than on how those weights affect input activations, ignore the covariance structure of activations, and enforce uniform rank across layers, causing activation drift and degraded attention fidelity. To address these issues, we propose CARE, a Covariance-Aware, Rank-Enhanced MLA conversion pipeline under a fixed KV width. CARE introduces three key steps: (i) activation-preserving factorization, which aligns the approximation with the actual input activations rather than just the weights; (ii) adjusted-rank allocation, which spreads a fixed KV budget across layers by giving more capacity to layers that need it most; and (iii) KV-parity mapping, which reparameterizes the converted K and V to fit the MLA format while keeping the KV-cache size unchanged. Our method outperforms a uniform-rank SVD baseline on Qwen3-4B/30B-A3B-Instruct-2507 and Llama-3.1-8B/70B-Instruct, reducing one-shot perplexity by up to 215x and improving mean accuracy by up to 1.70x at matched KV budgets. With a brief post-SVD healing fine-tune, we fully recover the original model's accuracy.
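Step (ii) can be illustrated with a small sketch. The abstract does not spell out the allocation rule, so the greedy scheme below is an assumption made for illustration, not CARE's actual scheduler: it takes one singular spectrum per layer, guarantees each layer a minimum rank, and hands out the rest of a fixed total budget to the largest remaining squared singular values across all layers. The function name `allocate_ranks` and its parameters are hypothetical.

```python
import numpy as np


def allocate_ranks(spectra, total_rank, min_rank=1):
    """Greedy global rank allocation under a fixed total budget.

    spectra    : list of 1-D arrays, singular values per layer (descending).
    total_rank : sum of ranks across all layers (the fixed budget).
    min_rank   : floor given to every layer before the greedy pass.

    NOTE: illustrative assumption, not the allocation rule from the paper.
    """
    n_layers = len(spectra)
    remaining = total_rank - min_rank * n_layers
    if remaining < 0:
        raise ValueError("budget too small for the requested min_rank")

    ranks = [min_rank] * n_layers

    # Every singular value beyond the floor is a candidate, scored by energy.
    candidates = []
    for layer, s in enumerate(spectra):
        energy = np.asarray(s, dtype=float) ** 2
        candidates.extend((float(e), layer) for e in energy[min_rank:])

    # Spend the remaining budget on the globally largest energies.
    candidates.sort(key=lambda item: item[0], reverse=True)
    for _, layer in candidates[:remaining]:
        ranks[layer] += 1
    return ranks


# Two toy layers: the second has a flatter spectrum and so receives more rank.
toy_spectra = [np.array([5.0, 1.0, 0.5, 0.1]), np.array([4.0, 3.0, 2.0, 1.5])]
print(allocate_ranks(toy_spectra, total_rank=4))
```

A uniform split of the same budget would assign rank 2 to both layers; the greedy pass instead shifts capacity toward the layer whose spectrum decays more slowly.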

Paper Structure (69 sections, 25 equations, 11 figures, 6 tables)

Figures (11)

  • Figure 1: (a) Naive MLA transfer: jointly factorize $W_K^{(g)}$ and $W_V^{(g)}$ by SVD and truncate to a uniform per-layer rank, optimizing $\|W-\hat{W}\|_F$ while ignoring layerwise heterogeneity. (b) CARE: estimate activation covariance $C$, factorize $\sqrt{C}W$, unwhiten via $\sqrt{C}^{-1}$ to initialize MLA factors, and use the singular spectrum of $\sqrt{C}W$ for global dynamic-rank scheduling under KV parity. This preserves activation geometry and yields a stronger one-shot initialization with less healing.
  • Figure 2: (a) Accuracy under 50% rank reduction applied one layer at a time in DeepSeek-V2-Lite, measured on ARC Challenge and MMLU. Sensitivity is strongly layer-dependent. (b) WikiText perplexity under grouped truncation of GQA attention in Llama-3-8B (layers 30--32). Singular-value groups are ordered by magnitude; the resulting non-monotone degradation shows that singular values alone are an imperfect proxy for MLA conversion quality.
  • Figure 3: Covariance-aware rank profiles across calibration corpora (Alpaca, WikiText2, PTB, C4) for Llama-3.1-8B-Instruct at target ranks 64, 128, 256, and 512. Across all target budgets, both $W_K$ and $W_V$ show a depth-dependent increase—small in early layers, rising through mid layers—with stronger late-layer growth for $W_V$. The consistency across corpora suggests a model-intrinsic trend.
  • Figure 4: One-shot accuracy versus calibration samples and sequence length across eight benchmarks. Sequence length is fixed at 256 unless otherwise noted; the red curve denotes fixed 256 samples with varying sequence length. OOM occurs beyond length $=512$ during covariance computation.
  • Figure I.1: Needle-in-a-Haystack retrieval heatmaps for Llama-3.1-8B-Instruct under matched KV budgets. The figure assembles five panel-wise heatmaps comparing CARE variants with different calibration settings against Palu(SVD) across context lengths and needle depths. Greener cells indicate stronger retrieval accuracy.
  • ...and 6 more figures
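
The covariance-aware factorization described in the Figure 1(b) caption can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions rather than the authors' implementation: activations `X` are rows of shape `(n, d_in)`, the weight `W` has shape `(d_in, d_out)` so the layer output is `X @ W`, the covariance is the uncentered second moment `X.T @ X / n`, and small eigenvalues are clipped at a hypothetical `eps` for numerical stability. The function names are also hypothetical.

```python
import numpy as np


def sqrt_and_inv_sqrt(C, eps=1e-6):
    """Symmetric square root of a PSD matrix and its inverse via eigh."""
    evals, evecs = np.linalg.eigh(C)
    root = np.sqrt(np.clip(evals, eps, None))  # clip guards near-singular C
    sqrt_C = (evecs * root) @ evecs.T
    inv_sqrt_C = (evecs / root) @ evecs.T
    return sqrt_C, inv_sqrt_C


def covariance_aware_factors(W, X, rank, eps=1e-6):
    """Rank-`rank` factors (down, up) such that X @ down @ up approximates X @ W."""
    C = X.T @ X / X.shape[0]                      # activation covariance estimate
    sqrt_C, inv_sqrt_C = sqrt_and_inv_sqrt(C, eps)
    U, S, Vt = np.linalg.svd(sqrt_C @ W, full_matrices=False)
    down = inv_sqrt_C @ (U[:, :rank] * S[:rank])  # unwhiten the left factor
    up = Vt[:rank]
    return down, up, S                            # S can feed a rank scheduler


def plain_svd_factors(W, rank):
    """Weight-only baseline: truncated SVD of W itself."""
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    return U[:, :rank] * S[:rank], Vt[:rank]


# Toy comparison on activations with strongly anisotropic scale per feature.
rng = np.random.default_rng(0)
X = rng.standard_normal((512, 16)) * np.linspace(0.1, 10.0, 16)
W = rng.standard_normal((16, 8))

down, up, _ = covariance_aware_factors(W, X, rank=4)
svd_down, svd_up = plain_svd_factors(W, rank=4)

print("covariance-aware output error:", np.linalg.norm(X @ W - X @ down @ up))
print("weight-only SVD output error: ", np.linalg.norm(X @ W - X @ svd_down @ svd_up))
```

Because $\|XW - X\hat{W}\|_F^2 = n\,\|\sqrt{C}(W-\hat{W})\|_F^2$ when $C = X^\top X / n$, truncating the SVD of $\sqrt{C}W$ minimizes the output error on the calibration activations among rank-constrained $\hat{W}$ (when $C$ is invertible), whereas truncating the SVD of $W$ minimizes only the weight error $\|W-\hat{W}\|_F$.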

Theorems & Definitions (2)

  • proof
  • proof