FlexPrefill: A Context-Aware Sparse Attention Mechanism for Efficient Long-Sequence Inference

Xunhao Lai, Jianqiao Lu, Yao Luo, Yiyuan Ma, Xun Zhou

TL;DR

FlexPrefill introduces a dynamic sparse attention mechanism for long-sequence inference that adapts the sparsity ratio and attention pattern of each head in real time. It combines Query-Aware Sparse Pattern Determination, guided by Jensen-Shannon divergence, with Cumulative-Attention Based Index Selection, governed by a cumulative attention threshold $\gamma$. Across LLaMA-3.1, GLM-4, Yi-9B, and Qwen-2, evaluated on RULER and InfiniteBench, it delivers speedups while preserving or improving accuracy relative to strong baselines. Performance holds across varying input lengths and task types, offering a practical path to efficient long-context inference in modern LLMs.

Abstract

Large language models (LLMs) encounter computational challenges during long-sequence inference, especially in the attention pre-filling phase, where the complexity grows quadratically with the prompt length. Previous efforts to mitigate these challenges have relied on fixed sparse attention patterns or on sparse attention patterns identified from a limited set of cases. However, these methods lack the flexibility to adapt efficiently to varying input demands. In this paper, we introduce FlexPrefill, a Flexible sparse Pre-filling mechanism that dynamically adjusts sparse attention patterns and computational budget in real time to meet the specific requirements of each input and attention head. The flexibility of our method is demonstrated through two key innovations: 1) Query-Aware Sparse Pattern Determination: By measuring Jensen-Shannon divergence, this component adaptively switches between query-specific diverse attention patterns and predefined attention patterns. 2) Cumulative-Attention Based Index Selection: This component dynamically selects the query-key indexes to be computed based on different attention patterns, ensuring that the sum of attention scores meets a predefined threshold. FlexPrefill adaptively optimizes the sparse pattern and sparse ratio of each attention head based on the prompt, enhancing efficiency in long-sequence inference tasks. Experimental results show significant improvements in both speed and accuracy over prior methods, providing a more flexible and efficient solution for LLM inference.
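The two components named in the abstract can be illustrated with a small, pure-Python sketch. This is not the authors' implementation: the function names, the list-based representation of a single attention distribution, and the exact decision rule (choosing the query-aware pattern when the divergence falls below a threshold) are assumptions made for illustration. The abstract only states that Jensen-Shannon divergence drives the pattern switch and that index selection stops once the summed attention scores reach a predefined threshold.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    def kl(a, b):
        # Terms with a == 0 contribute nothing; b is the mixture, so b > 0 wherever a > 0.
        return sum(x * math.log(x / y) for x, y in zip(a, b) if x > 0)

    m = [(x + y) / 2 for x, y in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def choose_pattern(estimated, reference, tau):
    """Hypothetical decision rule (an assumption, not taken from the source):
    use the query-aware pattern when the estimated attention distribution is
    close to the reference one, otherwise fall back to a predefined pattern."""
    if js_divergence(estimated, reference) < tau:
        return "query_aware"
    return "predefined"


def select_indices(scores, gamma):
    """Greedily pick the highest-scoring indices until their normalized
    scores sum to at least gamma (the cumulative attention threshold)."""
    total = sum(scores)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen, covered = [], 0.0
    for i in order:
        chosen.append(i)
        covered += scores[i] / total
        if covered >= gamma:
            break
    return sorted(chosen)


if __name__ == "__main__":
    print(choose_pattern([0.5, 0.5], [0.5, 0.5], tau=0.1))
    print(choose_pattern([1.0, 0.0], [0.0, 1.0], tau=0.1))
    print(select_indices([4, 1, 3, 2], gamma=0.6))
```

In this sketch, a larger gamma keeps more indices per head (higher accuracy, more computation), while a smaller gamma yields a sparser selection. This mirrors the per-head adaptive budget the abstract describes, although the real method operates on blocks of queries and keys rather than on a single flat score list.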

Paper Structure

This paper contains 51 sections, 18 equations, 9 figures, 11 tables.

Figures (9)

  • Figure 2: Comparison of attention patterns in different attention heads. The Diverse patterns (a) show scattered attention with independent blocks across query positions, while the Structured patterns (b) exhibit attention focused along certain structures.
  • Figure 3: Adaptive sparsity ratios across different attention heads and layers for varying sample complexities. Each heatmap shows the sparsity rate of different context lengths and task types given a fixed attention score coverage, where darker colors indicate more attention calculations. The sparsity distribution of different attention heads varies with different sample types (a, b) and different context lengths (b, c).
  • Figure 4: Comparison of our method with MInference and Fixed Budget Vertical-Slash attention, showing the trade-off between model performance and attention latency. Our method consistently outperforms MInference and Fixed Budget Vertical-Slash across different attention latencies. More details are provided in the appendix.
  • Figure 5: Comparison of our method under different thresholds $\tau$ indicates that an appropriate threshold $\tau$ improves model performance.
  • Figure 6: Sparse Attention Latency Breakdown Comparison across different context lengths. The graph shows the contribution of different components (Sparse Attention, Index Search, Pattern Search, and Representative Attention) to the overall latency for various input lengths. As input length increases, the proportion of time spent on sparse attention computation grows, while other components' relative contributions decrease.
  • ...and 4 more figures