
Multi-Head Low-Rank Attention

Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo

TL;DR

This work proposes Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding and achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8$\times$ decoding speedup over MLA.

Abstract

Long-context inference in large language models is bottlenecked by Key-Value (KV) cache loading during the decoding stage: because generation is sequential, the KV cache must be transferred from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at every step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits such as weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8$\times$ decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.
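The sharding bottleneck can be made concrete by counting the cached elements each device must read per token. The following is a minimal back-of-the-envelope sketch, not the paper's implementation, and it rests on several assumptions:

- Latent heads are spread across TP devices as evenly as possible, and every device holds at least one head. A single MLA latent is therefore replicated on all devices.
- The sizes are hypothetical: a 512-dimensional MLA latent versus four 128-dimensional latent heads for "MLRA-4". Any decoupled RoPE key component is ignored.
- The helper names `per_device_kv_elems` and `per_step_bytes` are made up for illustration.

```python
import math

# Illustrative sketch only: head counts, dimensions, and layer count below are
# hypothetical, not configurations taken from the MLRA paper.


def per_device_kv_elems(num_latent_heads: int, latent_dim_per_head: int,
                        tp_degree: int) -> int:
    """Cached elements one device reads per token per layer.

    Assumes latent heads are spread across the TP devices as evenly as
    possible, with every device holding at least one head. A single latent
    head (as in MLA) therefore ends up replicated on all devices.
    """
    heads_per_device = math.ceil(num_latent_heads / tp_degree)
    return heads_per_device * latent_dim_per_head


def per_step_bytes(seq_len: int, num_layers: int, elems_per_token: int,
                   bytes_per_elem: int = 2) -> int:
    """HBM bytes one device reads for one decoding step at batch size 1
    (bytes_per_elem=2 assumes a 16-bit cache such as BF16)."""
    return seq_len * num_layers * elems_per_token * bytes_per_elem


TP = 4
# MLA: one 512-dim latent that cannot be partitioned across devices.
mla = per_device_kv_elems(num_latent_heads=1, latent_dim_per_head=512, tp_degree=TP)
# Hypothetical MLRA-4: four 128-dim latent heads, one per device.
mlra = per_device_kv_elems(num_latent_heads=4, latent_dim_per_head=128, tp_degree=TP)

print(mla, mlra)  # 512 128
# Hypothetical 32K-token context and 24 layers.
print(per_step_bytes(32_768, 24, mla) / per_step_bytes(32_768, 24, mlra))  # 4.0
```

Under these assumptions, 4-way TP cuts each device's KV reads by 4× for the partitioned latent, while MLA's per-device reads stay constant. This count covers only KV cache bytes. It is not a latency model and is not meant to reproduce the paper's reported 2.8$\times$ speedup.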

Paper Structure

This paper contains 61 sections, 1 theorem, 85 equations, 9 figures, and 40 tables.

Key Result

Theorem 1

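Given two tokens with query $\bm{q}$ and key $\bm{k}$ at positions $t_q$ and $t_k$, respectively, let $\operatorname{RoPE}\left(\bm{q}, t_q\right)$ and $\operatorname{RoPE}\left(\bm{k}, t_k\right)$ denote the RoPE-encoded vectors. We show that translating both positions by an offset $s$ leaves the inner product unchanged: $\left\langle \operatorname{RoPE}\left(\bm{q}, t_q + s\right), \operatorname{RoPE}\left(\bm{k}, t_k + s\right) \right\rangle = \left\langle \operatorname{RoPE}\left(\bm{q}, t_q\right), \operatorname{RoPE}\left(\bm{k}, t_k\right) \right\rangle$. Equivalently, for the attention-score matrix $\phi(\bm{X})\in\mathbb{R}^{n\times n}$ induced by RoPE, …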
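The shift invariance in Theorem 1 follows because RoPE rotates each 2-D coordinate pair by an angle proportional to position. The dot product of two rotated vectors therefore depends only on $t_q - t_k$. Below is a minimal numerical check under these assumptions:

- It uses the interleaved-pair RoPE convention with base 10000. The summary does not state the paper's exact convention, but the half-split variant obeys the same invariance.
- The helpers `rope` and `score_matrix` are hypothetical names introduced here.

```python
import numpy as np


def rope(x: np.ndarray, pos: float, base: float = 10000.0) -> np.ndarray:
    """Rotary position embedding for one vector x of even length d.

    Interleaved convention (an assumption): dims (2i, 2i+1) are rotated
    by the angle pos * base**(-2i/d).
    """
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE needs an even dimension"
    i = np.arange(d // 2)
    theta = pos * base ** (-2.0 * i / d)
    cos, sin = np.cos(theta), np.sin(theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out


def score_matrix(Q: np.ndarray, K: np.ndarray, positions: np.ndarray) -> np.ndarray:
    """n x n matrix of RoPE-encoded query-key dot products (no softmax)."""
    Qr = np.stack([rope(q, p) for q, p in zip(Q, positions)])
    Kr = np.stack([rope(k, p) for k, p in zip(K, positions)])
    return Qr @ Kr.T


rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
t_q, t_k, s = 17, 5, 1000

score = rope(q, t_q) @ rope(k, t_k)
shifted = rope(q, t_q + s) @ rope(k, t_k + s)
print(np.isclose(score, shifted))  # True

# The same check for a whole sequence: shifting every position by s
# leaves the full score matrix unchanged.
n = 8
Q, K = rng.standard_normal((n, 64)), rng.standard_normal((n, 64))
pos = np.arange(n)
print(np.allclose(score_matrix(Q, K, pos), score_matrix(Q, K, pos + s)))  # True
```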

Figures (9)

  • Figure 1: Loss difference between $\mathcal{N}(0, \sigma=0.02)$ and zero initialization, calculated by subtracting the loss of the latter from the former.
  • Figure 2: Loss difference between models without and with scaling, calculated by subtracting the loss of the latter from the former.
  • Figure 3: Loss difference between models with and without double heads, calculated by subtracting the loss of the latter from the former.
  • Figure 4: Loss difference between models without and with gating, calculated by subtracting the loss of the latter from the former.
  • Figure 5: Decoding latency (lower is better) versus sequence length (batch=1) for GQA, MLA, GLA-2, and MLRA-4.
  • ...and 4 more figures

Theorems & Definitions (8)

  • Remark 1
  • Theorem 1
  • Proof
  • Remark 2
  • Remark 3
  • Remark 4
  • Remark 5
  • Remark 6