S2-Attention: Hardware-Aware Context Sharding Among Attention Heads

Xihui Lin, Yunan Zhang, Suyu Ge, Liliang Ren, Barun Patra, Vishrav Chaudhary, Hao Peng, Xia Song

TL;DR

Transformers’ quadratic attention imposes large training and serving costs on modern LLMs. The authors present S2-Attention, a Triton-based, hardware-aware sparse-attention kernel that supports per-head and per-context-range sharding, complemented by a KV-efficient sparsity strategy and MergeQ to reduce memory traffic. They also introduce HHST, a Head-Heterogeneous Strided Transformer that combines heterogeneous context sharding with a hybrid dense-sparse architecture. HHST achieves strong downstream performance with substantial wall-clock speedups, including up to 25× in training and 4.5× in inference at long contexts. The work provides practical guidelines for designing efficient sparse attention, demonstrates effectiveness on 1.3B and 7B models, and releases an open-source toolkit interoperable with Megatron, PyTorch, HuggingFace, and vLLM to accelerate research and deployment.
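To make the memory-traffic argument behind MergeQ concrete, here is a toy IO count in Python. It is not the S2-Attention kernel. The model assumes a hypothetical block-sparse pattern defined at a KV-block granularity (`kv_block`) that is coarser than the query tile (`q_tile`), so several query tiles need exactly the same KV blocks. `kv_blocks_for`, `kv_loads`, `stride`, `local_blocks`, and `max_merge` (a stand-in for the SRAM budget) are illustrative names. Without merging, every query tile reloads its KV blocks, as in plain FlashAttention-2 tiling. With merging, tiles that share an identical KV set load it once.

```python
import math
from collections import defaultdict


def kv_blocks_for(q_pos, kv_block, stride, local_blocks):
    """Hypothetical block-sparse pattern: KV blocks read by the query at q_pos.

    Causal: every stride-th block plus the local_blocks most recent blocks.
    """
    cur = q_pos // kv_block
    return frozenset(b for b in range(cur + 1)
                     if b % stride == 0 or cur - b < local_blocks)


def kv_loads(seq_len, q_tile, kv_block, stride, local_blocks, merge, max_merge=8):
    """Count KV-block loads from HBM into SRAM under a toy tiling model."""
    tiles = []
    for start in range(0, seq_len, q_tile):
        queries = range(start, min(start + q_tile, seq_len))
        # A query tile must load the union of KV blocks its queries attend to.
        tiles.append(frozenset().union(
            *(kv_blocks_for(q, kv_block, stride, local_blocks) for q in queries)))
    if not merge:
        # Plain tiling: each tile loads its KV blocks independently.
        return sum(len(t) for t in tiles)
    # MergeQ-style grouping: tiles with identical KV sets share one load,
    # with at most max_merge tiles per group (stand-in for the SRAM budget).
    group_sizes = defaultdict(int)
    for t in tiles:
        group_sizes[t] += 1
    return sum(len(t) * math.ceil(n / max_merge) for t, n in group_sizes.items())


naive = kv_loads(1024, q_tile=16, kv_block=64, stride=4, local_blocks=2, merge=False)
merged = kv_loads(1024, q_tile=16, kv_block=64, stride=4, local_blocks=2, merge=True)
print(naive, merged)
```

With these settings, four 16-query tiles map onto each 64-token KV block. The merged count is therefore a quarter of the unmerged count. In this toy, the saving grows with the number of query tiles that share a KV set, up to the `max_merge` cap.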

Abstract

Sparse attention, which selectively attends to a subset of tokens in the context, was supposed to be efficient. However, its theoretical reduction in FLOPs has rarely translated into wall-clock speed-up over its dense attention counterparts due to the lack of hardware-aware optimizations like FlashAttention. Meanwhile, it remains unclear whether, and how, sparse attention can maintain model quality at the scale of today's large language models (LLMs). This paper presents Sparsely-Sharded (S2) Attention, a Triton library that provides kernel optimization for sparse attention customizable at both the per-head and per-context-range levels. S2-Attention enables the exploration of novel, high-performance sparse attention techniques, which we demonstrate through extensive ablations across a wide range of sparse attention designs at various model scales. From these insights, we present several basic guidelines for designing sparse attention that achieves not only practical efficiency improvements but also strong downstream performance. To achieve high parallelization and optimized memory IO, sparse attention should shard the context heterogeneously across attention heads, where each head attends to a different subset of tokens while collectively covering the full context. Meanwhile, we find hybrid architectures combining sparse and dense attention particularly beneficial in practice. S2-Attention achieves wall-clock speedups of 8.79X, 15.87X, and 25.3X over the strong FlashAttention-2 baseline, with downstream performance on par with full attention and perfect retrieval performance at a 128K context length. At inference time, our 7B model, served with the S2-Attention kernel, achieves a 4.5X speed-up over its dense counterpart. S2-Attention is released with easy-to-customize APIs for direct usage in Megatron and vLLM.
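The heterogeneous-sharding guideline can be illustrated with a small NumPy sketch. It builds per-head causal masks and checks the two properties the abstract names: each head sees only part of the context, and together the heads cover all of it. The pattern is an assumption for illustration, not the paper's exact configuration. Head `h` keeps the keys whose absolute position is congruent to `h` modulo the number of heads, plus a short local window (`local_window`). `head_sharded_mask` and `covers_full_context` are hypothetical helpers.

```python
import numpy as np


def head_sharded_mask(seq_len, num_heads, local_window=4):
    """Return a bool array m of shape (num_heads, seq_len, seq_len).

    m[h, i, j] is True when query i may attend key j in head h. Hypothetical
    pattern: head h owns keys with j % num_heads == h, plus a causal local window.
    """
    i = np.arange(seq_len)[:, None]  # query positions (column vector)
    j = np.arange(seq_len)[None, :]  # key positions (row vector)
    causal = j <= i
    local = (i - j) < local_window
    masks = np.empty((num_heads, seq_len, seq_len), dtype=bool)
    for h in range(num_heads):
        shard = (j % num_heads) == h  # this head's strided slice of the context
        masks[h] = causal & (shard | local)
    return masks


def covers_full_context(masks):
    """True if the union over all heads equals the full causal (dense) mask."""
    seq_len = masks.shape[-1]
    dense = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    return bool(np.array_equal(masks.any(axis=0), dense))


masks = head_sharded_mask(seq_len=16, num_heads=4)
dense_pairs = 16 * 17 // 2  # number of (query, key) pairs in a causal 16x16 mask
print([round(masks[h].sum() / dense_pairs, 2) for h in range(4)], covers_full_context(masks))
```

Each head computes only a fraction of the dense attention pairs, but no key position is dropped from the layer as a whole. Tying the stride to absolute rather than relative position means a key outside a head's local window is either always or never in that head's shard. This is the property that keeps the pattern compatible with a compact KV cache.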

Paper Structure (26 sections, 2 equations, 8 figures, 1 table)

Figures (8)

  • Figure 1: Illustration of S2-Attention with four attention heads on a hypothetical GPU with 4 thread blocks. Each attention head is allocated a shard of the context.
  • Figure 2: Training efficiency and long-context analysis of S2-Attention. Our model, implemented with our kernel, achieves a substantial reduction in latency compared to FlashAttention-2 (a). It also achieves perfect retrieval performance at a 128K context length (b).
  • Figure 3: Illustration of why KV eviction methods can cause more fragmentation. Here we show 3 pages of KV blocks containing 2 requests. Although many tokens were evicted, the released slots can hardly be reused by other requests, leading to a higher rate of internal fragmentation.
  • Figure 4: Illustration of the S2-Attention implementation. Left: directly applying FlashAttention-2 tiling to sparse attention. Right: MergeQ, which adaptively merges queries that share the same KV when loading them into SRAM, thus reducing redundant KV loading and improving IO efficiency.
  • Figure 5: (a) Dilated attention based on relative position, an example of sparse attention that is not KV-efficient. For example, step 5 attends to KV at positions 1, 3, 5, while step 4 attends to 0, 2, 4. Because consecutive steps need disjoint positions, every token must be retained: although the pattern suggests nearly 50% memory savings on paper, it requires storing the full KV cache in practice. (b) All of these attention patterns are KV-efficient. Each token is pushed to the KV cache when first encountered during decoding and is then attended continuously for several steps. It then either is evicted and never attended again (e.g., all tokens in the left figure and token 0 in the right figure), or remains attended by all future tokens (e.g., tokens 0, 2, 4 in the middle figure and tokens 2, 4 in the right figure). The arrows show that they all share a "vertical line" pattern.
  • ...and 3 more figures
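The KV-efficiency criterion in Figure 5 can be checked mechanically. The sketch below, in Python with NumPy, assumes that a token's KV entry must stay cached from the step that produces it until the last step that attends to it. Under that assumption, a pattern is KV-efficient exactly when every cached token is also attended at every step, which is the "vertical line" shape in the caption. `pattern_mask`, `kv_footprint`, and `is_kv_efficient` are illustrative helpers. The strided pattern, an absolute stride `stride` plus a local window `window`, is a hypothetical example.

```python
import numpy as np


def pattern_mask(n, attends):
    """Causal bool mask m[i, j] (decoding step i, key position j) from a predicate."""
    m = np.zeros((n, n), dtype=bool)
    for i in range(n):
        for j in range(i + 1):
            m[i, j] = attends(i, j)
    return m


def kv_footprint(m):
    """Per decoding step: (tokens that must be cached, tokens actually attended).

    A token stays cached from the step that produces it until the last step
    that attends it, whether or not the steps in between use it.
    """
    n = m.shape[0]
    positions = np.arange(n)
    # Last step attending each key (-1 if no step ever attends it).
    last_use = np.where(m.any(axis=0), n - 1 - np.argmax(m[::-1], axis=0), -1)
    cached = np.array([np.sum((positions <= i) & (last_use >= i)) for i in range(n)])
    attended = m.sum(axis=1)
    return cached, attended


def is_kv_efficient(m):
    """True if no step ever caches a token it does not attend (no gaps)."""
    cached, attended = kv_footprint(m)
    return bool(np.array_equal(cached, attended))


n, stride, window = 16, 4, 3
# Figure 5(a)-style dilation keyed to relative position.
dilated = pattern_mask(n, lambda i, j: (i - j) % 2 == 0)
# Hypothetical absolute-position stride plus a local window.
strided = pattern_mask(n, lambda i, j: j % stride == 0 or i - j < window)
print(is_kv_efficient(dilated), is_kv_efficient(strided))
print(kv_footprint(dilated)[0].max(), kv_footprint(dilated)[1].max())
```

Under this model, the relative-position dilated pattern attends to at most 8 tokens per step, yet at its peak it must keep 15 of the 16 tokens cached. The absolute-stride pattern never caches a token it is not currently attending.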