S2-Attention: Hardware-Aware Context Sharding Among Attention Heads
Xihui Lin, Yunan Zhang, Suyu Ge, Liliang Ren, Barun Patra, Vishrav Chaudhary, Hao Peng, Xia Song
TL;DR
Transformers’ quadratic attention imposes large training and serving costs on modern LLMs. The authors present S2-Attention, a Triton-based, hardware-aware sparse-attention kernel that supports per-head and per-context-range sharding, complemented by a KV-efficient sparsity strategy and Merge-Q to reduce memory traffic. They also introduce HHST, a Head-Heterogeneous Strided Transformer that combines heterogeneous context sharding with a hybrid dense-sparse architecture, achieving strong downstream performance with substantial wall-clock speedups, including up to 25× in training and 4.5× in inference at long contexts. The work provides practical guidelines for designing efficient sparse attention, demonstrates effectiveness across 1.3B and 7B models, and releases an open-source toolkit interoperable with Megatron, PyTorch, HuggingFace, and vLLM to accelerate research and deployment.
Abstract
Sparse attention, which selectively attends to a subset of tokens in the context, was supposed to be efficient. However, its theoretical reduction in FLOPs has rarely translated into wall-clock speedup over its dense attention counterparts, due to the lack of hardware-aware optimizations like FlashAttention. Meanwhile, it remains unclear whether, and how, sparse attention can maintain model quality at the scale of today's large language models (LLMs). This paper presents Sparsely-Sharded (S2) Attention, a Triton library that provides kernel optimization for sparse attention customizable at both per-head and per-context-range levels. S2-Attention enables the exploration of novel and high-performance sparse attention techniques, which we demonstrate through extensive ablations across a wide range of sparse attention designs at various model scales. From these insights, we present several basic guidelines for designing sparse attention that achieves not only practical efficiency improvements but also strong downstream performance. To achieve high parallelization and optimized memory IO, sparse attention should shard the context heterogeneously across attention heads, where each head attends to a different subset of tokens while the heads collectively cover the full context. Meanwhile, we find hybrid architectures combining sparse and dense attention particularly beneficial in practice. S2-Attention achieves wall-clock speedups of 8.79X, 15.87X, and 25.3X compared to the strong FlashAttention-2 baseline, with downstream performance on par with full attention and perfect retrieval performance at a 128k context length. At inference, for 7B models, our model, with the help of our S2-Attention kernel, achieves a 4.5X speedup compared to dense counterparts. S2-Attention is released with easy-to-customize APIs for direct usage in Megatron and vLLM.
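
The guideline of sharding the context heterogeneously across heads can be illustrated with a small sketch. The code below is not the S2-Attention API or kernel. It is a pure-Python illustration under assumed choices: the context is split into KV blocks, each head keeps a short local window plus the earlier blocks whose index matches the head's own offset, and attention is causal. The function names and the `local_blocks` parameter are hypothetical.

```python
def head_kv_blocks(q_block, head, num_heads, local_blocks=1):
    """Return the sorted KV block indices that `head` reads for `q_block`.

    Assumed pattern (illustrative only, not the paper's exact layout):
      * every head keeps the most recent `local_blocks` blocks;
      * head h additionally keeps blocks b with b % num_heads == h.
    Only blocks at or before `q_block` are visible (causal attention).
    """
    local = set(range(max(0, q_block - local_blocks + 1), q_block + 1))
    strided = {b for b in range(q_block + 1) if b % num_heads == head}
    return sorted(local | strided)


def union_coverage(q_block, num_heads, local_blocks=1):
    """Return the KV blocks covered by at least one head for `q_block`."""
    covered = set()
    for head in range(num_heads):
        covered.update(head_kv_blocks(q_block, head, num_heads, local_blocks))
    return sorted(covered)


if __name__ == "__main__":
    num_heads, q_block = 4, 7
    for head in range(num_heads):
        # Each head reads only a fraction of the 8 visible blocks.
        print(head, head_kv_blocks(q_block, head, num_heads))
    # Together the heads still see every visible block.
    print(union_coverage(q_block, num_heads) == list(range(q_block + 1)))
```

In this toy setting each head loads at most three of the eight visible KV blocks, while the union over the four heads covers the whole context, which is the property the guideline asks for.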
