FSA: An Alternative Efficient Implementation of Native Sparse Attention Kernel
Ran Yan, Youhe Jiang, Zhuoming Chen, Haohui Mai, Beidi Chen, Binhang Yuan
TL;DR
This work tackles the $O(N^2)$ cost of long-context attention in large language models with Flash Sparse Attention (FSA), an alternative kernel implementation of Native Sparse Attention (NSA). FSA inverts NSA's kernel loop order so that KV blocks are processed in the outer loop, which removes padding and reduces memory traffic, and it employs index-based token selection, an online softmax pass, and a two-stage reduction. The authors provide a theoretical memory/FLOP analysis and extensive experiments on NVIDIA H20/H200 GPUs, reporting kernel-level speedups of up to $3.5\times$, end-to-end training speedups of up to $1.25\times$, and inference prefill speedups of up to $1.36\times$ over NSA, with larger gains relative to full attention. The improvements hold across varied GQA group sizes and sequence lengths, underscoring the importance of algorithm-system co-design for practical, hardware-efficient sparse attention.
Abstract
Recent advances in sparse attention mechanisms have demonstrated strong potential for reducing the computational cost of long-context training and inference in large language models (LLMs). Native Sparse Attention (NSA), one state-of-the-art approach, introduces natively trainable, hardware-aligned sparse attention that delivers a substantial system-level performance boost while maintaining accuracy comparable to full attention. However, the kernel implementation of NSA forces a loop order that is efficient only with a relatively large number of query heads in each Grouped Query Attention (GQA) group, whereas existing LLMs widely adopt a much smaller number of query heads per GQA group. This mismatch significantly limits the applicability of the sparse algorithmic advance. In this work, we propose Flash Sparse Attention (FSA), an alternative kernel implementation that enables efficient NSA computation on modern GPUs across a wide range of popular LLMs with varied, smaller numbers of query heads in each GQA group. Compared to the vanilla NSA kernel implementation, our empirical evaluation demonstrates that FSA achieves (i) up to 3.5x and on average 1.6x kernel-level latency reduction, (ii) up to 1.25x and on average 1.09x end-to-end training speedup on state-of-the-art LLMs, and (iii) up to 1.36x and on average 1.11x prefill-phase speedup in LLM generative inference. GitHub repository: https://github.com/Relaxed-System-Lab/Flash-Sparse-Attention.
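The loop-order difference between the two kernels can be illustrated with a toy, CPU-only sketch. It is not the authors' GPU kernel: it assumes a single head, no causal masking, and at least one selected KV block per query, and every name in it (`query_outer`, `kv_outer`, `selected`, `block`) is hypothetical. The first function loops over queries and gathers each query's selected KV blocks; the second loops over KV blocks, computes partial softmax statistics only for the queries that selected each block, and then merges the partials per query in a second stage.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def query_outer(q, k, v, selected, block):
    """Query-outer loop: each query gathers the tokens of its selected KV blocks."""
    d_v = len(v[0])
    out = []
    for i, qi in enumerate(q):
        idx = [t for b in selected[i] for t in range(b * block, (b + 1) * block)]
        scores = [dot(qi, k[t]) for t in idx]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wt * v[t][c] for wt, t in zip(w, idx)) / z
                    for c in range(d_v)])
    return out


def kv_outer(q, k, v, selected, block):
    """KV-block-outer loop followed by a per-query reduction (two stages)."""
    d_v = len(v[0])
    n_blocks = len(k) // block
    # Stage 1: for every (query, block) pair, store the block-local max,
    # the block-local sum of exponentials, and the unnormalised output.
    partials = [[] for _ in q]
    for b in range(n_blocks):
        toks = range(b * block, (b + 1) * block)
        # Index-based selection: visit only the queries that picked block b.
        for i in (i for i, sel in enumerate(selected) if b in sel):
            scores = [dot(q[i], k[t]) for t in toks]
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            o = [sum(wt * v[t][c] for wt, t in zip(w, toks)) for c in range(d_v)]
            partials[i].append((m, sum(w), o))
    # Stage 2: rescale each partial to the query's global max and normalise.
    out = []
    for parts in partials:
        m_all = max(m for m, _, _ in parts)
        z = sum(l * math.exp(m - m_all) for m, l, _ in parts)
        out.append([sum(o[c] * math.exp(m - m_all) for m, _, o in parts) / z
                    for c in range(d_v)])
    return out
```

Both functions compute the same attention output up to floating-point rounding; only the iteration order and the need for an intermediate buffer differ. How that trade-off plays out on real hardware for small GQA group sizes is the subject of the paper and is not captured by this sketch.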
