
Kascade: A Practical Sparse Attention Method for Long-Context LLM Inference

Dhruv Deshmukh, Saurabh Goyal, Nipun Kwatra, Ramachandran Ramjee

TL;DR

Kascade tackles the latency bottleneck of attention in long-context LLM inference by introducing a training-free sparse attention method that computes exact Top-k indices on a small set of automatically chosen anchor layers and reuses them in subsequent layers. It leverages cross-layer similarity and head-aware remapping, along with tile-based pooling and efficient TileLang kernels, to accelerate both prefill and decode without retraining. The approach yields up to 4.1x decode and 2.2x prefill speedups on H100 while preserving near-dense accuracy on LongBench and AIME-24, outperforming other training-free sparse methods at similar sparsity. By automating anchor-layer selection and incorporating per-head Top-k reuse, Kascade offers a practical, deployable solution for accelerating long-context inference across diverse models.
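The following pure-Python sketch gives a rough picture of the anchor/reuse split. An anchor layer computes exact Top-k indices for one decode query per head. Each head of a later reuse layer then attends only to the indices of an assigned anchor head. All function names and the `head_map` argument are hypothetical. The sketch does not show how Kascade actually builds its head remapping, pools queries into tiles, or runs these steps in TileLang kernels.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def top_k_indices(q, keys, k):
    """Anchor layer: exact Top-k over one query's post-softmax weights (needs a dense pass)."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([dot(q, key) * scale for key in keys])
    return sorted(range(len(keys)), key=lambda i: weights[i], reverse=True)[:k]

def sparse_attention(q, keys, values, idx):
    """Reuse layer: attend only to the keys in idx, renormalizing the softmax over that subset."""
    scale = 1.0 / math.sqrt(len(q))
    w = softmax([dot(q, keys[i]) * scale for i in idx])
    dim = len(values[0])
    return [sum(w[j] * values[i][d] for j, i in enumerate(idx)) for d in range(dim)]

def anchor_layer(queries, keys, k):
    """One Top-k index list per head; queries[h] and keys[h] belong to head h."""
    return [top_k_indices(q, K, k) for q, K in zip(queries, keys)]

def reuse_layer(queries, keys, values, anchor_idx, head_map):
    """Head h borrows the indices of anchor head head_map[h] (hypothetical head remapping)."""
    return [
        sparse_attention(q, K, V, anchor_idx[head_map[h]])
        for h, (q, K, V) in enumerate(zip(queries, keys, values))
    ]

# Usage: one head with four keys; the anchor keeps keys 2 and 0 for this query.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
values = [[1.0], [2.0], [3.0], [4.0]]
idx = anchor_layer([[2.0, 0.5]], [keys], k=2)  # [[2, 0]]
out = reuse_layer([[0.0, 1.0]], [keys], [values], idx, head_map=[0])
```

Only the anchor layer pays for scoring every key. The reuse layer's cost scales with `k` rather than with context length, which is where the decode speedup would come from.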

Abstract

Attention is the dominant source of latency during long-context LLM inference, an increasingly common workload with reasoning models and RAG. We propose Kascade, a training-free sparse attention method that leverages known observations, in particular that 1) post-softmax attention is intrinsically sparse, and 2) the identity of high-weight keys is stable across nearby layers. Kascade computes exact Top-k indices in a small set of anchor layers, then reuses those indices in the intermediate reuse layers. The anchor layers are selected algorithmically, via a dynamic-programming objective that maximizes cross-layer similarity over a development set, allowing easy deployment across models. The method incorporates efficient-implementation constraints (e.g., tile-level operations) in both prefill and decode attention. Top-k selection and reuse in Kascade are head-aware, and our experiments show that this is critical for high accuracy. Kascade achieves up to 4.1x speedup in decode attention and 2.2x speedup in prefill attention over a FlashAttention-3 baseline on H100 GPUs, while closely matching dense attention accuracy on long-context benchmarks such as LongBench and AIME-24.
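The abstract does not spell out the dynamic-programming objective, so the sketch below assumes one plausible form. It takes a development-set similarity matrix `sim[i][j]`, measuring how well layer `i`'s Top-k indices serve layer `j` (with `sim[j][j] = 1`). It then picks `num_anchors` anchor layers, with layer 0 always an anchor, to maximize the total similarity between every layer and its most recent anchor. The function `select_anchors` and this exact objective are illustrative assumptions, not the paper's formulation.

```python
def select_anchors(sim, num_anchors):
    """Choose anchors maximizing sum_j sim[anchor(j)][j], where anchor(j) is the
    latest anchor <= j. Assumption: layer 0 is always an anchor."""
    L = len(sim)
    if not 1 <= num_anchors <= L:
        raise ValueError("num_anchors must be in [1, number of layers]")

    def seg(a, b):
        # Score of layers a..b-1 all reusing anchor a's Top-k indices.
        return sum(sim[a][j] for j in range(a, b))

    NEG = float("-inf")
    # best[c][e]: best score for layers [0, e) using c anchors; back[c][e]: start of last segment.
    best = [[NEG] * (L + 1) for _ in range(num_anchors + 1)]
    back = [[-1] * (L + 1) for _ in range(num_anchors + 1)]
    for e in range(1, L + 1):
        best[1][e], back[1][e] = seg(0, e), 0
    for c in range(2, num_anchors + 1):
        for e in range(c, L + 1):
            for s in range(c - 1, e):  # s is the new anchor opening the last segment
                if best[c - 1][s] == NEG:
                    continue
                cand = best[c - 1][s] + seg(s, e)
                if cand > best[c][e]:
                    best[c][e], back[c][e] = cand, s
    # Walk the back-pointers to recover the anchor layers.
    anchors, e = [], L
    for c in range(num_anchors, 0, -1):
        e = back[c][e]
        anchors.append(e)
    return sorted(anchors), best[num_anchors][L]

# Usage on a toy 4-layer matrix: layers 2 and 3 agree strongly, so layer 2 becomes the second anchor.
sim = [
    [1.0, 0.9, 0.2, 0.1],
    [0.0, 1.0, 0.3, 0.2],
    [0.0, 0.0, 1.0, 0.95],
    [0.0, 0.0, 0.0, 1.0],
]
print(select_anchors(sim, 2))  # anchors [0, 2], score about 3.85
```

Each layer reuses only its latest preceding anchor, so any choice of anchors splits the layers into contiguous segments. That is why a segment-wise DP is exact for this assumed objective. As written, the cost is O(num_anchors * L^3), which is negligible for a few dozen layers.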

Paper Structure

This paper contains 18 sections, 6 equations, 8 figures, 3 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Attention weight covered by the top 256 keys across layers and heads. Except for layer 0, the remaining layers show high sparsity across the majority of heads. Model=Llama-3.1-8b-Instruct, Dataset=MuSiQue.
  • Figure 2: Oracle Top-$k$ attention results with varying Top-$k$ percentage. With layer 0 doing full attention, Oracle Top-$k$ matches the baseline score even with Top-$k$ set to 5%. Model=Llama-3.1-8b-Instruct, Dataset=2WikiMultihopQA.
  • Figure 3: Cross-layer similarity using the top 256 keys. A bright cell indicates that the Top-$k$ indices of layer $i$ cover a high fraction of the attention covered by layer $j$'s own Top-$k$ indices. Model=Llama-3.1-8b-Instruct, Dataset=MuSiQue.
  • Figure 4: Importance scores of attention blocks of all layers. Deeper layers have lower importance than the initial layers, and layer 0 has the highest importance. Model=Llama-3.1-8b-Instruct, Dataset=MuSiQue.
  • Figure 5: Comparison of Top-$k$ attention accuracy when pooling with pre- vs post-softmax attention scores, across different tile sizes. The Top-$k$ percentage here is $10\%$. The smallest tile size is $4$, where only the queries corresponding to the same key head are pooled. Post-softmax pooling is more robust to changes in tile size and does consistently well across all tile sizes. Model=Llama-3.1-8b-Instruct, Dataset=MuSiQue.
  • ...and 3 more figures
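The quantities behind Figures 1, 3 and 5 can be approximated in a few lines. The sketch below works on per-query attention over one head's keys. All function names are hypothetical. The similarity measure is read off the Figure 3 caption rather than taken from the paper's code. Tile pooling is shown for a tile of query rows that share one key head.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(weights, k):
    """Indices of the k largest entries."""
    return sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)[:k]

def coverage(weights, idx):
    """Fraction of one query's post-softmax attention mass on the keys in idx (cf. Figure 1)."""
    return sum(weights[i] for i in idx)

def cross_layer_similarity(weights_i, weights_j, k):
    """Share of layer j's own Top-k coverage recovered with layer i's Top-k indices (cf. Figure 3)."""
    own = coverage(weights_j, top_k(weights_j, k))
    return coverage(weights_j, top_k(weights_i, k)) / own

def pooled_top_k(logit_rows, k, post_softmax=True):
    """One shared Top-k set for a tile of queries (cf. Figure 5).

    post_softmax=True averages each query's probabilities; False averages raw logits.
    """
    rows = [softmax(r) for r in logit_rows] if post_softmax else logit_rows
    n = len(logit_rows[0])
    pooled = [sum(r[c] for r in rows) / len(rows) for c in range(n)]
    return top_k(pooled, k)

# Usage: a two-query tile where the two pooling rules disagree.
tile = [[4.0, 0.0, 0.0, 0.0], [-10.0, 2.0, 2.1, 2.0]]
print(pooled_top_k(tile, 1, post_softmax=True))   # [0]
print(pooled_top_k(tile, 1, post_softmax=False))  # [2]
```

In this toy tile, the first query puts about 95% of its mass on key 0. However, the second query's very negative logit drags the averaged logit for key 0 down to -3, so pre-softmax pooling drops key 0, while averaging probabilities keeps it. This offers one possible intuition, not a claim from the paper, for why post-softmax pooling is less sensitive to tile size.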