Hardware-Efficient Attention for Fast Decoding

Ted Zadouri, Hubert Strauss, Tri Dao

TL;DR

The paper tackles decoding latency in large language models caused by KV cache transfers from high-bandwidth memory by introducing hardware-aware attention variants. It presents GTA, which ties key and value states to reduce cache reads, and GLA, which uses small latent heads to boost arithmetic intensity and enable scalable tensor parallelism with minimal KV duplication. Empirical results show that GTA and GLA preserve or improve model quality across scales while delivering substantial speedups in decoding and online serving, with up to 2x reductions in end-to-end latency and up to 2x throughput gains. These methods, supported by low-level system optimizations and open-source kernels, offer practical paths to efficient, scalable inference on modern accelerators.
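To make the tied key/value layout of GTA concrete, here is a minimal NumPy sketch of one decoding step for the query heads of a single KV group, following the layout described in the Figure 2 caption. It is an illustration under stated assumptions, not the authors' kernel: the names `gta_group_decode`, `W_kv`, and `W_rope` are hypothetical; we assume the *first* half of the tied state feeds the key (the caption only says "half"); we use the common rotate-half RoPE convention; and we recompute projections from hidden states, whereas a real decoder would read them from the KV cache.

```python
import numpy as np


def rope(x, positions, base=10000.0):
    """Rotary position embedding (rotate-half convention) applied row-wise to x of shape (L, d)."""
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)        # (half,)
    angles = positions[:, None] * freqs[None, :]     # (L, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def gta_group_decode(q, hidden_ctx, W_kv, W_rope):
    """One grouped-tied attention decoding step for a single KV group (sketch).

    q:          (h_group, d_h) queries of the heads in this group (assumed to already
                carry RoPE on their second half).
    hidden_ctx: (L, d_model) hidden states of the L context tokens.
    W_kv:       (d_model, d_h) hypothetical projection to the tied KV state.
    W_rope:     (d_model, d_h // 2) hypothetical single-head projection for the RoPE half of the key.
    """
    d_h = W_kv.shape[1]
    tied = hidden_ctx @ W_kv                                          # (L, d_h) tied KV state
    k_rope = rope(hidden_ctx @ W_rope, np.arange(len(hidden_ctx)))    # (L, d_h // 2)
    # Key = half of the tied state (no RoPE) concatenated with the RoPE half,
    # broadcast to every query head in the group.
    keys = np.concatenate([tied[:, : d_h // 2], k_rope], axis=-1)     # (L, d_h)
    values = tied                                                     # full tied state is the value
    scores = q @ keys.T / np.sqrt(d_h)                                # (h_group, L)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))       # numerically stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ values                                             # (h_group, d_h)


rng = np.random.default_rng(0)
d_model, d_h, L, h_group = 32, 16, 5, 4
out = gta_group_decode(
    rng.standard_normal((h_group, d_h)),
    rng.standard_normal((L, d_model)),
    rng.standard_normal((d_model, d_h)),
    rng.standard_normal((d_model, d_h // 2)),
)
print(out.shape)  # (4, 16)
```

The point of the design is visible in `values = tied`: the same cached tensor is read once and serves as the value and as half of the key, which is where the reduced memory traffic relative to GQA comes from.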

Abstract

LLM decoding is bottlenecked for large batches and long contexts by loading the key-value (KV) cache from high-bandwidth memory, which inflates per-token latency, while the sequential nature of decoding limits parallelism. We analyze the interplay among arithmetic intensity, parallelization, and model quality and question whether current architectures fully exploit modern hardware. This work redesigns attention to perform more computation per byte loaded from memory to maximize hardware efficiency without trading off parallel scalability. We first propose Grouped-Tied Attention (GTA), a simple variant that combines and reuses key and value states, reducing memory transfers without compromising model quality. We then introduce Grouped Latent Attention (GLA), a parallel-friendly latent attention paired with low-level optimizations for fast decoding while maintaining high model quality. Experiments show that GTA matches Grouped-Query Attention (GQA) quality while using roughly half the KV cache and that GLA matches Multi-head Latent Attention (MLA) and is easier to shard. Our optimized GLA kernel is up to 2$\times$ faster than FlashMLA, for example, in a speculative decoding setting when the query length exceeds one. Furthermore, by fetching a smaller KV cache per device, GLA reduces end-to-end latency and increases throughput in online serving benchmarks by up to 2$\times$.
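The sharding argument above (GLA is easier to shard and fetches a smaller KV cache per device) can be checked with back-of-the-envelope arithmetic. The sketch below counts KV cache bytes per token, per layer, held on each tensor-parallel rank. All of the following are our assumptions, not the paper's accounting: the total latent width is split evenly across latent heads, as in the GLA-2 example of Figure 1; latent heads are sharded across ranks and duplicated once the TP degree exceeds the number of heads; a 64-wide decoupled RoPE key is replicated on every rank; storage is BF16. The 512-wide latent matches the DeepSeek-V2/V3 MLA configuration; the GLA-8 split of it and the helper name `latent_kv_bytes_per_rank` are hypothetical.

```python
def latent_kv_bytes_per_rank(total_latent, n_latent_heads, tp, rope_width=64, bytes_per_elem=2):
    """KV cache bytes per token per layer stored on each tensor-parallel rank.

    Simplified model (assumptions, not the paper's exact accounting):
    - total_latent is split evenly across n_latent_heads;
    - latent heads are sharded across ranks; when tp > n_latent_heads each rank
      still needs one full head, so heads are duplicated across ranks;
    - the decoupled RoPE key of width rope_width is replicated on every rank;
    - bytes_per_elem=2 corresponds to BF16 storage.
    """
    if n_latent_heads % tp and tp % n_latent_heads:
        raise ValueError("tp and n_latent_heads must divide one another")
    head_width = total_latent // n_latent_heads
    heads_per_rank = max(1, n_latent_heads // tp)
    return (heads_per_rank * head_width + rope_width) * bytes_per_elem


for tp in (1, 2, 4, 8):
    mla = latent_kv_bytes_per_rank(512, n_latent_heads=1, tp=tp)   # MLA: one latent head
    gla8 = latent_kv_bytes_per_rank(512, n_latent_heads=8, tp=tp)  # hypothetical GLA-8 split
    print(f"TP={tp}: MLA {mla} B/rank, GLA-8 {gla8} B/rank")
# MLA stays at 1152 B/rank at every TP degree; GLA-8 falls to 256 B/rank at TP=8.
```

Under these assumptions both variants store the same bytes per token on a single GPU, but MLA's single latent head is copied to every rank as TP grows, which is the duplication that the Figure 5 baseline mitigates by mixing data parallelism into MLA's layout.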

Paper Structure

This paper contains 50 sections, 7 equations, 15 figures, 45 tables.

Figures (15)

  • Figure 1: Memory-loading schematics during decoding of MLA (Left) and GLA-2 (Right) illustrate reduced data movement and higher arithmetic intensity, achieving more FLOPs per byte accessed and easing the memory-bound bottleneck. In MLA, a single latent head $c^{KV}$ with $d_c = 4d_h$ is loaded once from HBM to SRAM and reused as both $K$ and $V$ by every query head in $\sigma(QK^{\top})V$. In GLA-2, two latent heads, each with $d_c = 2 d_h$, are likewise loaded once and reused as $K$ and $V$ by every query head in their group, eliminating or mitigating cache duplication when query heads are sharded across devices.
  • Figure 2: Overview of Grouped-Tied Attention (GTA). A single projection produces a tied KV state that serves as both key and value. The full tied KV dimension is used as the value. For the keys, half of the key dimension comes from the tied KV vector (no positional encoding applied), and the other half comes from a separate single-head projection (where RoPE is applied); this separate half is broadcast to all heads in the group and concatenated with the tied KV half. GTA roughly doubles the arithmetic intensity and halves the KV cache size relative to GQA with the same number of groups.
  • Figure 3: Roofline analysis of BF16 decoding on a single H100 80GB SXM5. In this figure only, the numeric suffix (e.g., GQA-128) indicates the number of query heads $h_q$; elsewhere in the paper, it denotes $h_{kv}$. Left, $L_q{=}1$: with $h_q{=}128$, MLA attains an arithmetic intensity of $\sim 2\cdot h_q{=}256$, near the H100 compute roof of $\sim295$ FLOPs/byte, whereas GLA-128 with two latent heads remains on the I/O roof with an arithmetic intensity of $\sim h_q{=}128$, similar to MQA. Right, $L_q{=}2$ (e.g., speculative decoding with a query length of 2): MLA with $h_q{=}128$ climbs beyond the roof and becomes compute bound, while GLA with two latent heads sits at the inflection point and can run up to $2\times$ faster.
  • Figure 4: Left: Decoding speed of MLA and GLA on an H100 80GB SXM5 GPU (theoretical peak: 989 TFLOPS BF16 compute, 3350 GB/s memory bandwidth) at query length 1, where MLA is close to being compute bound (reaching 610 TFLOPS) while GLA has not yet saturated compute (360 TFLOPS). Right: Output throughput (higher is better) for 64 concurrent requests in the live server benchmark, where GLA outperforms MLA under an identical parallelism scheme; GLA-8 with pure TP=8 also outperforms MLA with a hybrid of TP and DP. The prefill/decode sequence lengths are 8192/4096, respectively.
  • Figure 5: Output throughput (higher is better) under the live server benchmark. Left: with 16 concurrent requests, long-context prefill lengths of 32K/64K, and a 4K decode length, GLA-8 with TP=8 outperforms MLA with a hybrid of TP and DP across eight GPUs. Right: with 16 concurrent requests, prefill lengths sampled uniformly up to 131K tokens, and decode lengths up to 4K tokens, GLA-8 with TP=8 delivers $2.7\times$ higher throughput than MLA with a hybrid of (TP=2, DP=4), where DP is used to mitigate MLA's KV cache duplication.
  • ...and 10 more figures
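The arithmetic intensities and the "up to $2\times$" claim in the Figure 3 and Figure 4 captions can be reproduced with a textbook roofline model. The sketch below is that generic model, not the paper's benchmark methodology: it counts only the $QK^{\top}$ and $PV$ matmuls (2 FLOPs per multiply-add) and KV cache reads, and ignores softmax, RoPE dimensions, and query/output traffic. The H100 peaks come from the Figure 4 caption; the MLA shape ($h_q = 128$, one 512-wide latent head, i.e. $d_c = 4d_h$ with $d_h = 128$) and the GLA-2 shape (two 256-wide heads) mirror Figure 1 and are otherwise assumptions, as are the helper names.

```python
PEAK_FLOPS = 989e12            # H100 SXM5 BF16 dense peak, FLOP/s (Figure 4 caption)
PEAK_BW = 3350e9               # H100 SXM5 HBM bandwidth, bytes/s (Figure 4 caption)
RIDGE = PEAK_FLOPS / PEAK_BW   # ~295 FLOPs/byte, the compute roof in Figure 3


def decode_cost(h_q, n_kv_heads, head_dim, kv_len, q_len=1, tied=True, bytes_per_elem=2):
    """Return (FLOPs, bytes read) for one attention layer decoding one sequence.

    tied=True:  each cached head is read once and used as both K and V (MLA/GLA latents).
    tied=False: separate K and V heads are read (MHA/GQA/MQA).
    """
    flops = 4 * h_q * q_len * kv_len * head_dim      # QK^T and PV, 2 FLOPs per multiply-add each
    n_tensors = 1 if tied else 2
    bytes_read = n_tensors * n_kv_heads * kv_len * head_dim * bytes_per_elem
    return flops, bytes_read


def roofline_time(flops, bytes_read):
    """Idealized runtime: whichever of compute or memory traffic takes longer."""
    return max(flops / PEAK_FLOPS, bytes_read / PEAK_BW)


def compare(q_len, kv_len=8192):
    """Arithmetic intensities of MLA and GLA-2 and the modeled GLA-2 speedup."""
    mla = decode_cost(128, 1, 512, kv_len, q_len)    # MLA: one latent head, d_c = 512
    gla = decode_cost(128, 2, 256, kv_len, q_len)    # GLA-2: two latent heads of 256 each
    ai_mla, ai_gla = mla[0] / mla[1], gla[0] / gla[1]
    speedup = roofline_time(*mla) / roofline_time(*gla)
    return ai_mla, ai_gla, speedup


for q_len in (1, 2, 4):
    ai_mla, ai_gla, speedup = compare(q_len)
    print(f"L_q={q_len}: AI MLA={ai_mla:.0f}, AI GLA-2={ai_gla:.0f}, "
          f"ridge={RIDGE:.0f}, modeled GLA-2 speedup={speedup:.2f}x")
# Modeled speedups are about 1.00x, 1.73x and 2.00x.
```

Both variants read the same number of bytes, but GLA-2 performs half the FLOPs. At $L_q = 1$ both are memory bound, so the model predicts equal speed; at $L_q = 2$ MLA crosses the ridge while GLA-2 does not; once both are compute bound the modeled gap saturates at $2\times$.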