Post-Training Sparse Attention with Double Sparsity

Shuo Yang, Ying Sheng, Joseph E. Gonzalez, Ion Stoica, Lianmin Zheng

TL;DR

This work tackles the KV-cache bottleneck in LLM inference by introducing Double Sparsity, a post-training sparse attention method that jointly exploits token sparsity and channel sparsity. Offline calibration identifies a static set of heavy channels, and a label cache keeps memory access contiguous, so important tokens can be identified cheaply while the full KV cache is retained. A companion variant, Double Sparsity-Offload, further reduces GPU memory use by offloading the KV cache to the CPU with a double-buffering scheme. The method achieves up to 14.1× faster attention and 1.9× faster end-to-end inference on GPUs; with offloading, decoding is up to 16.3× faster than state-of-the-art solutions at a sequence length of 256K. Accuracy is preserved across diverse tasks and models, and the code is open source.

Abstract

The inference process for large language models is slow and memory-intensive, with one of the most critical bottlenecks being excessive Key-Value (KV) cache accesses. This paper introduces "Double Sparsity," a novel post-training sparse attention technique designed to alleviate this bottleneck by reducing KV cache access. Double Sparsity combines token sparsity, which focuses on utilizing only the important tokens for computing self-attention, with channel sparsity, an approach that uses important feature channels for identifying important tokens. Our key insight is that the pattern of channel sparsity is relatively static, allowing us to use offline calibration to make it efficient at runtime, thereby enabling accurate and efficient identification of important tokens. Moreover, this method can be combined with offloading to achieve significant memory usage reduction. Experimental results demonstrate that Double Sparsity can achieve $\frac{1}{16}$ token and channel sparsity with minimal impact on accuracy across various tasks, including wiki-2 perplexity, key-value retrieval, and long context benchmarks with models including Llama-2-7B, Llama-2-70B, and Mixtral-8x7B. It brings up to a 14.1$\times$ acceleration in attention operations and a 1.9$\times$ improvement in end-to-end inference on GPUs. With offloading, it achieves a decoding speed acceleration of 16.3$\times$ compared to state-of-the-art solutions at a sequence length of 256K. Our code is publicly available at https://github.com/andy-yang-1/DoubleSparse.
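
The two kinds of sparsity described in the abstract can be illustrated with a minimal single-head NumPy sketch. It is not the authors' implementation: the function names, the calibration criterion (mean absolute key value per channel), and the toy data are assumptions made for illustration only. The sketch picks heavy channels offline, stores them in a small contiguous label cache, uses that cache to score tokens approximately, and then runs exact attention over only the top-scoring tokens.

```python
import numpy as np

def calibrate_heavy_channels(sample_keys, num_channels):
    """Offline step (assumed criterion): keep the channels with the
    largest mean absolute value over a sample of key vectors."""
    importance = np.abs(sample_keys).mean(axis=0)
    return np.sort(np.argsort(importance)[-num_channels:])

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def double_sparse_decode(q, K, V, label_cache, channels, num_tokens):
    """One decode step for a single head.
    q: (d,), K and V: (seq_len, d), label_cache: (seq_len, len(channels))."""
    # Channel sparsity: approximate scores read only the small label cache.
    approx_scores = label_cache @ q[channels]
    # Token sparsity: keep the tokens with the highest approximate scores.
    top = np.argpartition(approx_scores, -num_tokens)[-num_tokens:]
    # Exact attention, restricted to the selected tokens.
    scores = K[top] @ q / np.sqrt(q.shape[0])
    return softmax(scores) @ V[top], top

# Toy data: 256 cached tokens, head dimension 64, four artificially heavy channels.
rng = np.random.default_rng(0)
seq_len, head_dim = 256, 64
K = rng.standard_normal((seq_len, head_dim))
K[:, :4] *= 8.0
V = rng.standard_normal((seq_len, head_dim))
q = rng.standard_normal(head_dim)

# 1/16 channel sparsity and 1/16 token sparsity, as in the abstract.
channels = calibrate_heavy_channels(K, head_dim // 16)
label_cache = np.ascontiguousarray(K[:, channels])  # contiguous copy of heavy channels
out, top = double_sparse_decode(q, K, V, label_cache, channels, seq_len // 16)
```

Because the full `K` and `V` stay in memory and only the selection step is approximate, setting both ratios to 1 reproduces dense attention exactly. That gives a simple sanity check for any variant of this sketch.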

Paper Structure

This paper contains 35 sections, 1 equation, 9 figures, 4 tables, 1 algorithm.

Figures (9)

  • Figure 1: Double Sparsity Decode
  • Figure 2: Decoding process of Double Sparsity.
  • Figure 3: Performance of different techniques across various sparsity levels for long context benchmarks. 'DS' and 'DS-O' refer to Double Sparsity and Double Sparsity-Offloading. 'Stream' refers to Streaming-LLM.
  • Figure 4: Retrieval accuracy across various sparsity levels and context lengths. 'DS' and 'DS-O' refer to Double Sparsity and Double Sparsity-Offloading. 'Stream' refers to Streaming-LLM. 'RTN' refers to RTN Quantization.
  • Figure 5: Latency and speedup of Double Sparsity Attention at various batch sizes and sequence lengths. 'DS' indicates double sparsity attention. 'Flash' indicates the 'scaled_dot_product_attention', which is the fastest of FlashAttention-2 and Memory-Efficient Attention.
  • ...and 4 more figures