TidalDecode: Fast and Accurate LLM Decoding with Position Persistent Sparse Attention

Lijie Yang, Zhihao Zhang, Zhuofu Chen, Zikun Li, Zhihao Jia

TL;DR

TidalDecode addresses the memory and latency bottlenecks of decoding long-context LLMs by introducing Position Persistent Sparse Attention (PPSA) and a KV cache correction mechanism. It leverages overlap of high-attention tokens across consecutive layers to reuse selected token sets, using two token-selection layers to identify top-k tokens and performing PPSA on the rest, thereby reducing token-selection overhead without sacrificing quality. Empirical results across multiple LLaMA variants and long-context tasks show near full-attention performance with significant end-to-end speedups (up to 2.1x vs full attention, 2.17x vs Quest) and robust performance on needle-in-the-haystack, PG-19 perplexity, and LongBench. The optimal token re-selection layer position depends on the model family, but two-layer re-selection consistently yields strong results, and a cache-correction mechanism mitigates KV-distribution drift during long generation.

Abstract

Large language models (LLMs) have driven significant advancements across diverse NLP tasks, with long-context models gaining prominence for handling extended inputs. However, the expanding key-value (KV) cache size required by Transformer architectures intensifies the memory constraints, particularly during the decoding phase, creating a significant bottleneck. Existing sparse attention mechanisms designed to address this bottleneck have two limitations: (1) they often fail to reliably identify the most relevant tokens for attention, and (2) they overlook the spatial coherence of token selection across consecutive Transformer layers, which can lead to performance degradation and substantial overhead in token selection. This paper introduces TidalDecode, a simple yet effective algorithm and system for fast and accurate LLM decoding through position persistent sparse attention. TidalDecode leverages the spatial coherence of tokens selected by existing sparse attention methods and introduces a few token selection layers that perform full attention to identify the tokens with the highest attention scores, while all other layers perform sparse attention with the pre-selected tokens. This design enables TidalDecode to substantially reduce the overhead of token selection for sparse attention without sacrificing the quality of the generated results. Evaluation on a diverse set of LLMs and tasks shows that TidalDecode closely matches the generative performance of full attention methods while reducing the LLM decoding latency by up to 2.1x.
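
The layer schedule described above can be sketched in a few lines. This is a toy, single-head illustration rather than the authors' implementation: the names `attend`, `decode_step`, and `toy_cache` are hypothetical, the projections are omitted (the hidden state is used directly as the query), and the choice of two leading dense layers follows the Figure 2 caption. Selection layers run full attention and record the top-k cache positions; every other layer after the dense prefix attends only to those recorded positions.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values, positions):
    """Single-head attention restricted to the given cache positions.
    Returns the output vector and the weights aligned with `positions`."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * x for q, x in zip(query, keys[p])) for p in positions]
    weights = softmax(scores)
    dim = len(values[0])
    out = [sum(w * values[p][d] for w, p in zip(weights, positions)) for d in range(dim)]
    return out, weights

def decode_step(hidden, kv_cache, selection_layers, k, dense_layers=2):
    """Toy decoding step (assumption: hidden state doubles as the query).
    kv_cache is a list of (keys, values) pairs, one per layer."""
    selected = None  # positions chosen by the most recent selection layer
    used = []        # positions each layer actually attended to
    for layer, (keys, values) in enumerate(kv_cache):
        all_positions = list(range(len(keys)))
        if layer < dense_layers:
            # Leading layers: plain full attention, no selection.
            positions = all_positions
            out, _ = attend(hidden, keys, values, positions)
        elif layer in selection_layers:
            # Selection layer: full attention, then keep the top-k positions.
            positions = all_positions
            out, weights = attend(hidden, keys, values, positions)
            ranked = sorted(all_positions, key=lambda p: weights[p], reverse=True)
            selected = sorted(ranked[:k])
        else:
            # Sparse layer: reuse the persisted positions (fall back to
            # full attention if no selection layer has run yet).
            positions = selected if selected is not None else all_positions
            out, _ = attend(hidden, keys, values, positions)
        used.append(positions)
        hidden = [h + o for h, o in zip(hidden, out)]  # residual connection
    return hidden, used

def toy_cache(num_layers, num_tokens, dim):
    """Deterministic made-up keys and values, for illustration only."""
    cache = []
    for layer in range(num_layers):
        keys = [[math.sin(layer + t * (d + 1)) for d in range(dim)] for t in range(num_tokens)]
        values = [[math.cos(layer + t + d) for d in range(dim)] for t in range(num_tokens)]
        cache.append((keys, values))
    return cache

hidden_out, used = decode_step(
    [0.1, -0.2, 0.3, 0.4], toy_cache(6, 8, 4), selection_layers={2, 4}, k=3
)
print([len(p) for p in used])  # [8, 8, 8, 3, 8, 3]
```

Only the selection layers pay the cost of scoring the whole cache; the sparse layers touch `k` entries each, which is where the reduction in token-selection overhead would come from in this simplified picture.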

Paper Structure

This paper contains 21 sections, 1 equation, 9 figures, 12 tables, 1 algorithm.

Figures (9)

  • Figure 1: Heatmap for one decoding step of Llama3-8B-Instruct (gradient-ai-llama-3-8B), where columns and rows correspond to Transformer layers and tokens in the KV cache, respectively. For each layer, the 5 tokens (10% sparsity) with the highest attention scores in the first attention head are highlighted in yellow; these are the tokens used for sparse attention. The input prompt is "Use only the provided search results to write a high-quality, concise answer to the question.\n<|begin_of_text|>\n The magic number is: 15213. \n\n\n Question: What is the magic number? Keep the response short and direct. Answer: ", and the LLM outputs "15213". The results show strong spatial coherence among the tokens chosen for sparse attention in this decoding step.
  • Figure 2: An overview of the decoding step in TidalDecode, which performs full attention for the first two layers, full attention with token selection for the third layer and a middle layer, and position persistent sparse attention for all other layers.
  • Figure 3: Results from retrieving the top-256 tokens in a 100K-context-length Needle-in-the-Haystack test conducted on PG-19-mini. The overlap-matrix panel shows the overlap ratio of tokens with the highest attention scores across layers: consecutive layers tend to share a large number of critical tokens. The recall-rate panel shows that the choice of re-selection layer strongly affects recall, with a clear peak marking the optimal layers for token re-selection.
  • Figure 4: Cache Correction
  • Figure 5: Perplexity evaluation on the PG-19 dataset from 0 to 32K tokens. The results compare TidalDecode with different token re-selection layers (L9, L13, L15) to Quest at token budgets of 2048 and 4096 (one panel each). Lower perplexity indicates better model performance. "Full" denotes the dense-attention baseline.
  • ...and 4 more figures
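
The overlap ratio mentioned in the Figure 3 caption can be illustrated with a small sketch. The definition used here is an assumption, since the paper's exact formula is not reproduced in this listing: the ratio is taken as the fraction of one layer's top-k positions that also appear in another layer's top-k set. The function names and the attention scores below are made up for illustration.

```python
def top_k_positions(scores, k):
    """Indices of the k largest attention scores (ties broken by position)."""
    ranked = sorted(range(len(scores)), key=lambda p: scores[p], reverse=True)
    return set(ranked[:k])

def overlap_ratio(scores_a, scores_b, k):
    """Assumed definition: shared top-k positions divided by k."""
    shared = top_k_positions(scores_a, k) & top_k_positions(scores_b, k)
    return len(shared) / k

# Hypothetical attention scores for two layers over a 5-token cache.
layer_a = [0.05, 0.40, 0.10, 0.30, 0.15]
layer_b = [0.10, 0.35, 0.05, 0.20, 0.30]

ratio = overlap_ratio(layer_a, layer_b, k=2)
print(ratio)  # 0.5: position 1 is shared, positions 3 and 4 are not
```

A ratio near 1 between two layers suggests that the second layer could reuse the first layer's selected positions with little loss, which is the property position persistent sparse attention relies on.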