S2O: Early Stopping for Sparse Attention via Online Permutation

Yu Zhang, Songwei Liu, Chenqian Yan, Sheng Lin, Beichen Ning, Fangmin Chen, Xing Wang

TL;DR

Inspired by virtual-to-physical address mapping in memory systems, S2O revisits and factorizes FlashAttention execution, enabling inference to load non-contiguous tokens rather than a contiguous span in the original order.

Abstract

Attention scales quadratically with sequence length, fundamentally limiting long-context inference. Existing block-granularity sparsification can reduce latency, but coarse blocks impose an intrinsic sparsity ceiling, making further improvements difficult even with carefully engineered designs. We present S2O, which performs early stopping for sparse attention via online permutation. Inspired by virtual-to-physical address mapping in memory systems, S2O revisits and factorizes FlashAttention execution, enabling inference to load non-contiguous tokens rather than a contiguous span in the original order. Motivated by fine-grained structures in attention heatmaps, we transform explicit permutation into an online, index-guided, discrete loading policy; with extremely lightweight preprocessing and index-remapping overhead, it concentrates importance on a small set of high-priority blocks. Building on this importance-guided online permutation for loading, S2O further introduces an early-stopping rule: computation proceeds from high to low importance; once the current block score falls below a threshold, S2O terminates early and skips the remaining low-contribution blocks, thereby increasing effective sparsity and reducing computation under a controlled error budget. As a result, S2O substantially raises the practical sparsity ceiling. On Llama-3.1-8B under a 128K context, S2O reduces single-operator MSE by 3.82$\times$ at matched sparsity, and reduces prefill compute density by 3.31$\times$ at matched MSE; meanwhile, it preserves end-to-end accuracy and achieves 7.51$\times$ attention and 3.81$\times$ end-to-end speedups.
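The loading and early-stopping idea described above can be illustrated with a minimal single-query NumPy sketch. It is not the authors' kernel: the function name `early_stop_attention`, the use of exact logits as the importance proxy, the block score (largest unnormalized weight in the block relative to the running maximum), and the threshold `tau` are all assumptions made for illustration. The abstract does not specify the scoring rule, and a real implementation would presumably use a much cheaper importance estimate.

```python
import numpy as np

def early_stop_attention(q, K, V, block=4, tau=1e-6):
    """Single-query attention that visits keys from high to low importance.

    Assumptions (not from the paper): importance is the exact logit, and the
    block score is exp(block_max_logit - running_max_logit).
    Returns (output vector, number of blocks actually processed).
    """
    n, d = K.shape
    logits = (K @ q) / np.sqrt(d)
    # Index map playing the role of a logical-to-physical mapping:
    # position i in the permuted order points at token order[i].
    order = np.argsort(-logits, kind="stable")

    m = -np.inf                    # running max logit (online softmax)
    l = 0.0                        # running softmax denominator
    acc = np.zeros(V.shape[1])     # running weighted sum of values
    used = 0
    for start in range(0, n, block):
        idx = order[start:start + block]   # non-contiguous gather, no tensor is moved
        s = logits[idx]
        # Early stop: blocks are sorted, so every later block scores no higher.
        if used > 0 and np.exp(s.max() - m) < tau:
            break
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)           # rescale previously accumulated state
        p = np.exp(s - m_new)
        l = l * corr + p.sum()
        acc = acc * corr + p @ V[idx]
        m = m_new
        used += 1
    return acc / l, used
```

With `tau=0.0` the loop never stops early and the result matches dense softmax attention; raising `tau` skips more trailing blocks at the cost of a larger approximation error.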

Paper Structure

This paper contains 41 sections, 2 equations, 7 figures, 7 tables, 4 algorithms.

Figures (7)

  • Figure 1: Speedup at 128K context length (Llama-3-8B) with sparsity--error trade-off. We report end-to-end latency speedup over FlashAttention and break down the major components, including sparse preprocessing time, attention compute time, and other overheads. We also show each method's sparsity ratio and the corresponding mean squared error (MSE; lower is better), highlighting that our method achieves lower error at higher sparsity.
  • Figure 2: Attention heatmaps under different permutation strategies. (a) Original: The heatmap exhibits abundant line-level (stripe-like) structures. (b) PBS: Following PBS's local $K/V$ permutation strategy, the heatmap still contains substantial redundancy and fails to consistently emphasize salient horizontal stripes. (c) Ours: Our global permutation scheme (Sec. \ref{sec:seg_rank}) compacts attention mass into a progressive region from dense (upper-left) to increasingly diffuse. More qualitative heatmaps are provided in Appendix \ref{app:heatmaps}.
  • Figure 3: Coordinate-scheduled online $Q$/$K$ permutation. Since the head dimension is typically 128, scattered token accesses can still fully utilize a warp; see Appendix \ref{sec:perm_overhead_analysis} for details.
  • Figure 4: S2O workflow. Step 1 (Permute) builds two lightweight index arrays without moving tensors: (i) an intra-segment query permutation index $Q_{\mathrm{perm}}$ (token-level, segment-local), and (ii) a prefix key/value logical map $KV_{\mathrm{perm}}$ (token-level, global indices) obtained by retrieving segment representatives. Step 2 (Sparse attention) runs attention in two passes: Pass-1 computes a dense intra-segment causal window to initialize online-softmax states; Pass-2 resumes the states and processes the historical prefix in the retrieved order with early stopping.
  • Figure 5: Evaluation overview on Llama-3.1-8B. (a) Operator-level accuracy under matched sparsity. (b) Attention speedup: dark bars denote attention time and light bars denote preprocessing time; the corresponding sparsity(s) and MSE are also reported. (c) End-to-end prefill latency breakdown across context lengths: dark bars denote total attention-related time (including preprocessing) and light bars denote all remaining overheads.
  • ...and 2 more figures
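
The two-pass structure in the Figure 4 caption relies on online-softmax states being resumable: a pass can stop, and a later pass can keep folding in keys in any order. The sketch below illustrates only that property, under stated assumptions. The names `online_update` and `two_pass_segment`, the argument `kv_perm` (an arbitrary ordering of prefix token indices standing in for $KV_{\mathrm{perm}}$), and the block size are hypothetical; early stopping and the query permutation $Q_{\mathrm{perm}}$ are deliberately left out.

```python
import numpy as np

def online_update(state, S, Vb):
    """Fold one block of logits S (queries x keys) and values Vb into (m, l, acc)."""
    m, l, acc = state
    m_new = np.maximum(m, S.max(axis=1))
    corr = np.exp(m - m_new)                  # rescale the earlier partial sums
    P = np.exp(S - m_new[:, None])
    return m_new, l * corr + P.sum(axis=1), acc * corr[:, None] + P @ Vb

def two_pass_segment(Q, K, V, start, seg, kv_perm, block=4):
    """Causal attention for queries [start, start+seg), computed in two passes."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    q = Q[start:start + seg]

    # Pass 1: dense causal window inside the segment initializes the states.
    S = (q @ K[start:start + seg].T) * scale
    S = np.where(np.tril(np.ones((seg, seg), dtype=bool)), S, -np.inf)
    state = (np.full(seg, -np.inf), np.zeros(seg), np.zeros((seg, V.shape[1])))
    state = online_update(state, S, V[start:start + seg])

    # Pass 2: resume the states and gather prefix keys in the given order.
    for b in range(0, len(kv_perm), block):
        idx = kv_perm[b:b + block]            # global token indices, non-contiguous
        state = online_update(state, (q @ K[idx].T) * scale, V[idx])

    m, l, acc = state
    return acc / l[:, None]
```

Because softmax normalization is order-independent once the running maximum and denominator are tracked, any permutation of the prefix indices yields the same output as dense causal attention when no block is skipped.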