
Stem: Rethinking Causal Information Flow in Sparse Attention

Lin Niu, Xin Luo, Linchuan Xie, Yifu Sun, Guanghua Yu, Jianchen Zhu, S Kevin Zhou

TL;DR

This paper proposes Stem, a plug-and-play sparsity module aligned with causal information flow. Stem employs a Token Position-Decay strategy, which applies a position-dependent top-k within each layer so that initial tokens, on which all later tokens recursively depend, are retained.
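The summary does not give Stem's exact decay function. As a minimal sketch, assume the keep ratio for the query at position i falls linearly from 1 (dense) at the first position to a floor r_min at the last, and compare it with a baseline that applies one constant keep ratio at every position (a stand-in for uniform top-k). The function names and the linear form are hypothetical; `Fraction` keeps the ceiling arithmetic exact.

```python
import math
from fractions import Fraction

def uniform_ratio_budgets(n, ratio):
    """Baseline (assumed): the same keep ratio at every query position, at least 1 key."""
    return [max(1, math.ceil(ratio * (i + 1))) for i in range(n)]

def position_decay_budgets(n, r_min):
    """Hypothetical linear Token Position-Decay schedule: the keep ratio falls
    from 1 (dense) at position 0 to r_min at position n - 1."""
    span = max(n - 1, 1)
    budgets = []
    for i in range(n):
        ratio = 1 - (1 - r_min) * Fraction(i, span)
        # Query i can see i + 1 causal keys; keep ceil(ratio * (i + 1)) of them.
        budgets.append(max(1, math.ceil(ratio * (i + 1))))
    return budgets

n = 8
uni = uniform_ratio_budgets(n, Fraction(1, 2))
dec = position_decay_budgets(n, Fraction(1, 4))
print(uni, sum(uni))             # → [1, 1, 2, 2, 3, 3, 4, 4] 20
print(dec, sum(dec))             # → [1, 2, 3, 3, 3, 3, 3, 2] 20
print(n * (n + 1) // 2)          # dense causal pairs → 36
```

With these illustrative settings both schedules keep 20 of the 36 causal query-key pairs, but the decaying schedule keeps the first three query rows fully dense (the uniform one already halves the second row) and recovers that budget from the tail.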

Abstract

The quadratic computational complexity of self-attention remains a fundamental bottleneck for scaling Large Language Models (LLMs) to long contexts, particularly during the pre-filling phase. In this paper, we rethink the causal attention mechanism from the perspective of information flow. Due to causal constraints, tokens at initial positions participate in the aggregation of every subsequent token. However, existing sparse methods typically apply a uniform top-k selection across all token positions within a layer, ignoring the cumulative dependency of token information inherent in causal architectures. To address this, we propose Stem, a novel, plug-and-play sparsity module aligned with information flow. First, Stem employs the Token Position-Decay strategy, applying position-dependent top-k within each layer to retain initial tokens for recursive dependencies. Second, to preserve information-rich tokens, Stem utilizes the Output-Aware Metric. It prioritizes high-impact tokens based on approximate output magnitude. Extensive evaluations demonstrate that Stem achieves superior accuracy with reduced computation and pre-filling latency.
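To make the Output-Aware Metric concrete, one natural reading is that key j matters to query i in proportion to the size of its contribution a_ij v_j to the attention output, that is a_ij ||v_j||. The sketch below assumes exactly that scoring rule (the paper's actual approximation is not described here) and, for clarity, computes the dense attention probabilities that a real sparse kernel would only approximate. All function names are hypothetical.

```python
import numpy as np

def causal_attention_probs(q, k):
    """Dense causal softmax probabilities a[i, j]; entries with j > i are 0."""
    n, d = q.shape
    logits = q @ k.T / np.sqrt(d)
    causal = np.tril(np.ones((n, n), dtype=bool))
    logits = np.where(causal, logits, -np.inf)
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

def output_aware_scores(q, k, v):
    """Hypothetical metric: ||a[i, j] * v[j]|| = a[i, j] * ||v[j]||,
    the magnitude of key j's contribution to the output of query i."""
    a = causal_attention_probs(q, k)
    return a * np.linalg.norm(v, axis=1)[None, :]

def select_top_k(scores, i, budget):
    """Positions of the `budget` highest-scoring visible keys (j <= i) for query i."""
    row = scores[i, : i + 1]
    budget = min(budget, i + 1)
    return np.sort(np.argsort(row)[::-1][:budget])

# Key 0 draws most of query 2's attention but carries a tiny value vector.
q = np.array([[1.0], [1.0], [1.0]])
k = np.array([[2.0], [0.0], [0.0]])
v = np.array([[0.1], [3.0], [1.0]])

attn = causal_attention_probs(q, k)
scores = output_aware_scores(q, k, v)
print(np.argmax(attn[2]))          # attention-only top-1 → 0
print(select_top_k(scores, 2, 1))  # output-aware top-1 → [1]
```

In this toy case key 0 receives about 79% of query 2's attention but its value vector has norm 0.1, so an attention-only top-1 keeps it, while the output-aware score keeps key 1, whose contribution to the output is larger.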

Paper Structure (33 sections, 19 equations, 5 figures, 5 tables, 1 algorithm)

Figures (5)

  • Figure 1: Latency comparison (ms) on an NVIDIA H20 GPU. Results are reported as attention kernel time / total time.
  • Figure 2: Visualization of recursive error propagation. The diagram depicts the impact of sparsification across layers $l \to l+2$ based on Eq. (1). Red circles indicate sparse tokens, while blue circles indicate dense tokens. Pruning the initial token $V_1^{(l+1)}$ triggers a global distortion that affects all tokens in the next layer (dashed red connections), whereas pruning the last token $V_N^{(l)}$ results in only a local error confined to the tail.
  • Figure 3: Sensitivity analysis of token position segments. The x-axis gives the token interval subject to sparsification; the y-axis shows the MSE loss of the head logits. Curves compare dynamic ratios with fixed budgets.
  • Figure 4: Pipeline of Stem. The figure illustrates the framework workflow and compares the sparsity budget schedules between the standard Uniform Top-$k$ and our Token Position-Decay strategy.
  • Figure 5: Hyperparameter ablation studies on LongBench.
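The asymmetry in Figure 2 follows from the causal mask: output i aggregates only the value vectors at positions up to i, so a change to the first value vector reaches every output row, whereas a change to the last one reaches only the final row. The single-layer sketch below uses a value perturbation as a stand-in for pruning (an assumption; Figure 2 itself tracks the effect across layers l to l+2), random inputs, 0-based positions, and hypothetical helper names.

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head causal softmax attention (dense reference)."""
    n, d = q.shape
    logits = q @ k.T / np.sqrt(d)
    causal = np.tril(np.ones((n, n), dtype=bool))
    logits = np.where(causal, logits, -np.inf)
    logits -= logits.max(axis=1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v

def affected_rows(q, k, v, pos, eps=1.0):
    """Hypothetical probe: perturb value vector `pos` and return the
    (0-based) output rows whose values change."""
    v_pert = v.copy()
    v_pert[pos] += eps
    diff = np.abs(causal_attention(q, k, v_pert) - causal_attention(q, k, v))
    return np.flatnonzero(diff.max(axis=1) > 1e-12)

rng = np.random.default_rng(0)
n, d = 8, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
print(affected_rows(q, k, v, 0))      # first token → [0 1 2 3 4 5 6 7]
print(affected_rows(q, k, v, n - 1))  # last token  → [7]
```

Because every affected row becomes an input to the next layer, the first-token error reaches all positions there as well, which is the recursive dependency that motivates retaining initial tokens.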