Sparser is Faster and Less is More: Efficient Sparse Attention for Long-Range Transformers

Chao Lou, Zixia Jia, Zilong Zheng, Kewei Tu

TL;DR

Processing long sequences in autoregressive Transformers is hampered by the quadratic cost of self-attention and an ever-growing KV cache. The paper introduces SparseK Attention, which combines a learnable KV scoring network with a differentiable SparseK operator to select a fixed number of KV pairs per query, yielding linear time complexity and a constant memory footprint during generation. With method extensions, training techniques, and a fused kernel, SparseK outperforms prior sparse attention methods on language modeling and downstream tasks, and can be integrated into pre-trained LLMs with minimal fine-tuning. Empirical results show improved perplexity, robust length extrapolation, and favorable efficiency trade-offs, making SparseK a practical solution for long-range dependency modeling.

Abstract

Accommodating long sequences efficiently in autoregressive Transformers, especially within an extended context window, poses significant challenges due to the quadratic computational complexity and substantial KV memory requirements inherent in self-attention mechanisms. In this work, we introduce SPARSEK Attention, a novel sparse attention mechanism designed to overcome these computational and memory obstacles while maintaining performance. Our approach integrates a scoring network and a differentiable top-k mask operator, SPARSEK, to select a constant number of KV pairs for each query, thereby enabling gradient-based optimization. As a result, SPARSEK Attention offers linear time complexity and constant memory footprint during generation. Experimental results reveal that SPARSEK Attention outperforms previous sparse attention methods and provides significant speed improvements during both training and inference, particularly in language modeling and downstream tasks. Furthermore, our method can be seamlessly integrated into pre-trained Large Language Models (LLMs) with minimal fine-tuning, offering a practical solution for effectively managing long-range dependencies in diverse applications.
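
As a concrete illustration of the "linear time complexity and constant memory footprint during generation" claim, the following NumPy sketch decodes with a fixed budget of k KV pairs: each incoming pair is scored and low-scoring pairs are evicted, so every step touches only k entries. The linear scorer `w_score`, the greedy eviction policy, and all dimensions are hypothetical stand-ins, not the paper's actual scoring network, differentiable mask, or fused kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 64, 8                  # head dimension and KV budget (illustrative)
w_score = rng.normal(size=d)  # stand-in for the learned KV scoring network

def decode_step(q, new_k, new_v, cache):
    """One decoding step with a constant-size KV cache (hypothetical sketch).

    Each arriving KV pair receives a scalar score; only the k best-scoring
    pairs are kept, so memory stays O(k) and each step costs O(k), not O(t).
    """
    cache.append((float(new_k @ w_score), new_k, new_v))
    cache.sort(key=lambda item: item[0], reverse=True)
    del cache[k:]                                   # evict low-scoring pairs
    keys = np.stack([item[1] for item in cache])
    vals = np.stack([item[2] for item in cache])
    weights = np.exp(q @ keys.T / np.sqrt(d))       # attend over k pairs only
    weights /= weights.sum()
    return weights @ vals

cache = []
for _ in range(100):                                # a 100-token sequence
    q, new_k, new_v = rng.normal(size=(3, d))
    out = decode_step(q, new_k, new_v, cache)
print(len(cache), out.shape)                        # cache size stays at k=8
```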

Paper Structure

This paper contains 56 sections, 1 theorem, 22 equations, 8 figures, 9 tables, and 3 algorithms.

Key Result

Proposition A.1

The solution of $\textsc{SparseK}\coloneqq \mathop{\mathrm{arg\,min}}\limits_{{\bm{p}}\in \mathbb{C}}|| {\bm{p}}-{\bm{z}} ||^2$ is of the form ${\bm{p}}^* = \max(\min({\bm{z}} - \tau({\bm{z}}), \bm{1}), \bm{0})$, where $\mathbb{C} = \{{\bm{p}} \mid \bm{0} \le {\bm{p}} \le \bm{1},\ \bm{1}^\top{\bm{p}} = k\}$ and the threshold $\tau({\bm{z}})$ is chosen such that $\bm{1}^\top{\bm{p}}^* = k$.
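
Numerically, the threshold in Proposition A.1 is easy to find because $\bm{1}^\top \max(\min({\bm{z}} - \tau, \bm{1}), \bm{0})$ is monotonically non-increasing in $\tau$. Below is a minimal NumPy sketch that recovers $\tau$ by bisection and returns the projection; the paper derives the threshold exactly (and fuses selection and attention into one kernel), so bisection here is only for illustration.

```python
import numpy as np

def sparsek(z: np.ndarray, k: int, iters: int = 60) -> np.ndarray:
    """Euclidean projection of z onto C = {p : 0 <= p <= 1, 1^T p = k}.

    Per Proposition A.1, the solution is p* = clip(z - tau, 0, 1) with tau
    chosen so that p* sums to k. Since that sum is monotone non-increasing
    in tau, bisection converges to it (an illustrative sketch, not the
    paper's exact threshold computation).
    """
    lo, hi = z.min() - 1.0, z.max()        # sum(p) = len(z) at lo, 0 at hi
    for _ in range(iters):
        tau = 0.5 * (lo + hi)
        if np.clip(z - tau, 0.0, 1.0).sum() > k:
            lo = tau                       # too much mass: raise threshold
        else:
            hi = tau
    return np.clip(z - 0.5 * (lo + hi), 0.0, 1.0)

scores = np.array([2.0, 0.5, 1.4, -0.3, 0.9])
mask = sparsek(scores, k=3)
print(mask.round(3), mask.sum())           # soft top-k mask with mass ~3
```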

Figures (8)

  • Figure 1: Left: $\textsc{SparseK}$ operation in the attention module. KV pairs are scored by $\mathbf{u}$. $\textsc{SparseK}$ computes a threshold for each query ($\tau(\mathbf{u})$) such that the sum of normalized scores is $k$, which is 3 in this example. We select top-$k$ KV pairs (orange cells) to perform attention. Right: the $\textsc{SparseK}$ attention module. We fuse selection and attention in one kernel for efficiency.
  • Figure 2: Perplexity on the held-out set of fine-tuned models. L denotes the training context length.
  • Figure 3: Length extrapolation results. * denotes that the method is training-free. 2,048 is the context length of the original model; 8,192 is the context length used in fine-tuning.
  • Figure 4: Training from scratch.
  • Figure 5: Speed benchmark against FlashAttention-2. ST denotes the straight-through estimator.
  • ...and 3 more figures

Theorems & Definitions (2)

  • Proposition A.1
  • Proof of Proposition A.1