Prism: Spectral-Aware Block-Sparse Attention

Xinghao Wang, Pengyu Wang, Xiaoran Liu, Fangxu Liu, Jason Chu, Kai Song, Xipeng Qiu

TL;DR

This work targets the block-selection bottleneck of block-sparse attention in long-context processing. It shows that mean pooling, when combined with Rotary Positional Embeddings (RoPE), acts as a low-pass filter that erases high-frequency local positional information. It then introduces Prism, a training-free, spectral-aware framework. Prism splits block importance estimation into high- and low-frequency branches and uses energy-based calibration to restore the attenuated signals, so blocks can be scored with purely block-level operations. Prism matches the accuracy of full attention while delivering up to $5.1\times$ speedup across diverse long-context tasks and modalities, including language and video benchmarks. The approach offers a practical, scalable path to efficient long-context LLMs and multimodal models and applies broadly to RoPE variants.
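The low-pass claim can be made concrete with a standard geometric-series argument. The derivation below is a sketch under two simplifying assumptions that this summary does not state. First, RoPE rotates frequency pair $j$ at position $m$ by angle $m\theta_j$, with $\theta_j = b^{-2j/d}$. Second, the pre-rotation component is constant across the $B$ positions of a block. The paper's own definition of $\lambda_j(B)$ and its proof may differ. Intuitively, the $B$ rotated copies spread around the unit circle and cancel when $\theta_j$ is large, but stay aligned when $\theta_j$ is small.

```latex
% Assumptions: pair j is treated as a complex number rotated by m\theta_j at position m,
% \theta_j = b^{-2j/d}, and q^{(j)} is constant over positions m_0, ..., m_0 + B - 1.
\bar{q}^{(j)}
  = \frac{1}{B}\sum_{k=0}^{B-1} e^{i(m_0+k)\theta_j}\, q^{(j)}
  = e^{i m_0 \theta_j} \Bigl(\frac{1}{B}\sum_{k=0}^{B-1} e^{i k \theta_j}\Bigr) q^{(j)}

\sum_{k=0}^{B-1} e^{i k \theta}
  = \frac{e^{i B\theta} - 1}{e^{i\theta} - 1}
  = e^{i (B-1)\theta/2}\,\frac{\sin(B\theta/2)}{\sin(\theta/2)}

\bigl|\bar{q}^{(j)}\bigr| = \lambda_j(B)\,\bigl|q^{(j)}\bigr|,
\qquad
\lambda_j(B) = \left|\frac{\sin(B\theta_j/2)}{B\,\sin(\theta_j/2)}\right|

\theta_j \to 0:\ \lambda_j(B) \to 1 \quad (\text{low frequencies pass}),
\qquad
\lambda_j(B) \le \frac{1}{B\,\sin(\theta_j/2)} \quad (\text{small once } B\theta_j \gg 2\pi)
```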

Abstract

Block-sparse attention is promising for accelerating long-context LLM pre-filling, yet identifying relevant blocks efficiently remains a bottleneck. Existing methods typically employ coarse-grained attention as a proxy for block importance estimation, but often resort to expensive token-level searching or scoring, resulting in significant selection overhead. In this work, we trace the inaccuracy of standard coarse-grained attention via mean pooling to a theoretical root cause: the interaction between mean pooling and Rotary Positional Embeddings (RoPE). We prove that mean pooling acts as a low-pass filter that induces destructive interference in high-frequency dimensions, effectively creating a "blind spot" for local positional information (e.g., slash patterns). To address this, we introduce Prism, a training-free spectral-aware approach that decomposes block selection into high-frequency and low-frequency branches. By applying energy-based temperature calibration, Prism restores the attenuated positional signals directly from pooled representations, enabling block importance estimation using purely block-level operations, thereby improving efficiency. Extensive evaluations confirm that Prism maintains accuracy parity with full attention while delivering up to $\mathbf{5.1\times}$ speedup.
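For reference, the generic coarse-grained proxy that the abstract contrasts against can be sketched using block-level operations only. The sketch mean-pools queries and keys into blocks and takes a causal softmax over the block-level logits. It then keeps the smallest set of key blocks whose probability mass reaches a threshold `p`, a top-p rule that the Figure 4 caption also mentions.

This is a minimal pure-Python sketch of that baseline, not Prism. It omits the high/low-frequency split and the energy-based temperature calibration, and all function names are hypothetical. It assumes RoPE has already been applied to `q` and `k`, which is exactly where the pooling-induced attenuation enters.

```python
import math
from typing import List

Matrix = List[List[float]]

def mean_pool_blocks(x: Matrix, block: int) -> Matrix:
    """Average consecutive rows of x in groups of `block`; the last group may be shorter."""
    pooled = []
    for start in range(0, len(x), block):
        rows = x[start:start + block]
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled

def block_attention_probs(q: Matrix, k: Matrix, block: int) -> Matrix:
    """Causal softmax over block-level logits built from pooled (RoPE-rotated) Q and K."""
    qb, kb = mean_pool_blocks(q, block), mean_pool_blocks(k, block)
    scale = 1.0 / math.sqrt(len(q[0]))
    probs = []
    for i, qi in enumerate(qb):
        # Query block i may only attend to key blocks 0..i (causal pre-filling).
        logits = [scale * sum(a * b for a, b in zip(qi, kb[j])) for j in range(i + 1)]
        peak = max(logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(v - peak) for v in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs

def top_p_blocks(row: List[float], p: float) -> List[int]:
    """Smallest set of key blocks (highest probability first) whose mass reaches p."""
    order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    chosen, mass = [], 0.0
    for j in order:
        chosen.append(j)
        mass += row[j]
        if mass >= p:
            break
    return sorted(chosen)

# Toy example: two "topics" occupy the first and second halves of an 8-token sequence.
q = k = [[1.0, 0.0]] * 4 + [[0.0, 1.0]] * 4
probs = block_attention_probs(q, k, block=2)
print(top_p_blocks(probs[3], p=0.5))  # → [2, 3]
```

In a real kernel these loops would be batched tensor operations. The pure-Python form only makes each step explicit: no token-level scores are ever computed, only one logit per (query block, key block) pair.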

Paper Structure

This paper contains 22 sections, 22 equations, 11 figures, and 3 tables.

Figures (11)

  • Figure 1: Spectral Disentanglement of Attention Patterns. We visualize the attention score matrices computed using different spectral bands of RoPE. (Left) Low-Frequency Band: Captures global semantic dependencies (e.g., block-sparse patterns / vertical lines), acting as the semantic backbone. (Middle) High-Frequency Band: Strictly encodes fine-grained relative locality (e.g., slash lines), which is critical for local coherence. (Right) Full Spectrum: The superposition of both patterns.
  • Figure 2: Spectral attenuation factor $\lambda_j(B)$ with block size $B=128$ and head dimension $d=128$.
  • Figure 3: Comparison of Query RMS norms before and after pooling. Left (Token-level): While the Semantic Zone (blue) holds the highest energy, the Dead Zone (green) maintains a robust magnitude ($\text{RMS} \approx 1.0$), confirming that high-frequency dimensions are actively utilized by the pre-trained model. Right (Block-pooled): After pooling, energy in the Dead Zone collapses to near-zero due to destructive interference, while the Semantic Zone preserves its magnitude.
  • Figure 4: PyTorch-style implementation of Prism. Prism uses only block-level operations, for maximum efficiency. See the appendix for the top_p implementation.
  • Figure 5: Language modeling performance on PG19. We compare the Perplexity Degradation ($\Delta$PPL, solid lines, left axis) and Speedup (bars, right axis) across sequence lengths. Prism achieves a double win: it shows no perplexity degradation (sticking to the $\Delta \approx 0$ line) while delivering the highest speedup ($\mathbf{5.1\times}$ at 128K), significantly outperforming baselines that trade off accuracy for speed or suffer from high selection overhead.
  • ...and 6 more figures
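Figure 2 plots the spectral attenuation factor $\lambda_j(B)$ for $B = 128$ and $d = 128$. The sketch below reproduces such a curve numerically under an assumption, since this summary does not give the formula. It treats $\lambda_j(B)$ as the fraction of a RoPE frequency pair's magnitude that survives mean-pooling $B$ consecutive rotations of a fixed vector, $\lambda_j(B) = |\sin(B\theta_j/2) / (B \sin(\theta_j/2))|$, with standard RoPE frequencies $\theta_j = 10000^{-2j/d}$.

The sketch cross-checks the closed form against a literal average of unit phasors. It also counts pairs below an arbitrary 10% cutoff as a rough analogue of the "Dead Zone" in Figure 3. Function names and the cutoff are hypothetical.

```python
import cmath
import math
from typing import List

def rope_frequencies(d: int, base: float = 10000.0) -> List[float]:
    """Assumed standard RoPE angular frequencies theta_j = base^(-2j/d), j = 0..d/2-1."""
    return [base ** (-2.0 * j / d) for j in range(d // 2)]

def attenuation_closed_form(theta: float, block: int) -> float:
    """|sin(B*theta/2) / (B*sin(theta/2))|: magnitude kept after mean-pooling B rotations."""
    return abs(math.sin(block * theta / 2) / (block * math.sin(theta / 2)))

def attenuation_direct(theta: float, block: int) -> float:
    """Same quantity, computed by literally averaging the unit phasors e^{i k theta}."""
    total = sum(cmath.exp(1j * k * theta) for k in range(block))
    return abs(total) / block

B, d = 128, 128  # the setting quoted in the Figure 2 caption
lambdas = [attenuation_closed_form(t, B) for t in rope_frequencies(d)]
dead = [j for j, lam in enumerate(lambdas) if lam < 0.1]  # 0.1 is an arbitrary cutoff

print(f"lambda_0 = {lambdas[0]:.4f}, lambda_last = {lambdas[-1]:.4f}")
print(f"{len(dead)} of {len(lambdas)} frequency pairs keep less than 10% of their magnitude")
```

Index $j = 0$ is the highest frequency ($\theta_0 = 1$). There, only about 1.5% of the magnitude survives pooling, while the lowest-frequency pairs keep essentially all of it. This matches the low-pass picture.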