Block Sparse Flash Attention

Daniel Ohayon, Itay Lamprecht, Itay Hubara, Israel Cohen, Daniel Soudry, Noam Elata

TL;DR

Block-Sparse FlashAttention (BSFA) tackles the long-context bottleneck with a training-free, block-level gating mechanism: after computing exact QK scores, it compares each block's maximum score against a calibrated threshold and skips value-block processing for blocks that fall below it. Because selection uses exact scores rather than predicted importance, the method reduces both compute and memory traffic with minimal accuracy loss, reaching speedups of up to 1.10x on real-world reasoning benchmarks and up to 1.24x on needle-in-a-haystack retrieval with Llama-3.1-8B, and its thresholds generalize across datasets. BSFA extends FlashAttention-2 with a lightweight gating step and requires only a one-time offline threshold calibration on a small dataset plus a drop-in CUDA kernel, which makes it practical to deploy for retrieval and reasoning workloads. Its sparsity adapts to the task, and it outperforms prior sparse-attention methods that rely on pre- or post-scoring pruning or on approximations. The work shows that exact score-based block selection can yield reliable sparsity and measurable speedups without architectural changes or quantization.

Abstract

Modern large language models increasingly require long contexts for reasoning and multi-document tasks, but attention's quadratic complexity creates a severe computational bottleneck. We present Block-Sparse FlashAttention (BSFA), a drop-in replacement that accelerates long-context inference while preserving model quality. Unlike methods that predict importance before computing scores, BSFA computes exact query-key similarities to select the top-k most important value blocks for each query. By comparing per-block maximum scores against calibrated thresholds, we skip approximately 50% of the computation and memory transfers for pruned blocks. Our training-free approach requires only a one-time threshold calibration on a small dataset to learn the per-layer and per-head attention score distributions. We provide a CUDA kernel implementation that can be used as a drop-in replacement for FlashAttention. On Llama-3.1-8B, BSFA achieves up to 1.10x speedup on real-world reasoning benchmarks and up to 1.24x for needle-in-a-haystack retrieval tasks while maintaining above 99% baseline accuracy, with certain configurations even improving accuracy by focusing on the most relevant content, substantially outperforming existing sparse attention methods. The implementation is available at https://github.com/Danielohayon/Block-Sparse-Flash-Attention
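
To make the gating idea in the abstract concrete, here is a minimal pure-Python sketch, not the authors' CUDA kernel. It processes a single query vector with a FlashAttention-style online softmax and skips a value block whenever that block's maximum scaled score falls below a threshold. The function name `block_sparse_attention`, the absolute-threshold comparison, and the single-query simplification are assumptions made for illustration; the actual kernel works on query blocks with calibrated per-layer, per-head thresholds.

```python
import math

def block_sparse_attention(q, keys, values, block_size, threshold):
    """Single-query attention with hypothetical block-level gating.

    Returns (output_vector, number_of_value_blocks_kept).
    Assumption: a block is skipped when its max scaled score < threshold.
    """
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf          # running max for the online softmax
    denom = 0.0                      # running softmax denominator
    acc = [0.0] * len(values[0])     # running weighted sum of value rows
    kept = 0
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        # Exact scaled QK scores for this key block (always computed).
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        blk_max = max(scores)
        if blk_max < threshold:
            continue                 # gate: never touch this value block
        kept += 1
        v_blk = values[start:start + block_size]
        new_max = max(running_max, blk_max)
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - new_max)
            denom += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        running_max = new_max
    if kept == 0:
        raise ValueError("threshold pruned every block")
    return [a / denom for a in acc], kept

# Toy data: the second key block scores far below the first.
q = [1.0, 0.0]
keys = [[4.0, 0.0], [3.0, 0.0], [-4.0, 0.0], [-5.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0], [7.0, 7.0]]

dense_out, dense_kept = block_sparse_attention(q, keys, values, 2, -math.inf)
sparse_out, sparse_kept = block_sparse_attention(q, keys, values, 2, 0.0)
print(dense_kept, sparse_kept)   # 2 blocks kept without gating, 1 with it
print(dense_out, sparse_out)
```

With the threshold set to negative infinity the function reduces to ordinary softmax attention, which is a convenient sanity check. Note that the QK scores are still computed for every block; only the value-side work is skipped, which matches the abstract's description of pruning after exact scoring.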

Paper Structure

This paper contains 28 sections, 4 equations, 3 figures, 2 tables, 1 algorithm.

Figures (3)

  • Figure 1: Accuracy-latency trade-offs on RULER for 32K (top), 64K (middle), and 128K (bottom) sequences, with accuracy and speedup measured at the same sequence lengths. BSFA (blue) maintains high accuracy with consistent speedups, outperforming SpargeAttention (orange). Each point represents a different sparsity level; see Table \ref{tab:main-results} for specific configurations.
  • Figure 2: Performance on 64K Needle-in-a-Haystack task. BSFA maintains 99% accuracy even at extreme sparsity ($k=32$ blocks) while achieving 1.24$\times$ speedup, demonstrating its ability to dynamically adapt to task requirements. Tasks requiring targeted retrieval can leverage aggressive sparsity without sacrificing accuracy.
  • Figure 3: Accuracy vs. Time-to-First-Token (TTFT) speedup on LongBench benchmark. BSFA achieves consistent speedups while maintaining high accuracy, demonstrating effective cross-dataset generalization with thresholds calibrated on RULER.
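
The captions above refer to thresholds calibrated offline and to a block budget $k$. The calibration procedure itself is not described in this excerpt, so the following is only a hedged sketch of one plausible scheme: for each (layer, head) pair, take the $k$-th largest per-block maximum score for every calibration query and average those values to obtain a threshold. The names `calibrate_threshold` and `calibrate_all`, the input layout, and the choice of the mean as the aggregate are all assumptions, not the authors' method.

```python
def calibrate_threshold(block_max_rows, k):
    """Hypothetical calibration for one (layer, head).

    block_max_rows: one row per calibration query, each row holding the
    per-block maximum scores for that query.
    Returns the mean of the k-th largest block max across queries.
    """
    kth_values = []
    for row in block_max_rows:
        ranked = sorted(row, reverse=True)
        # Clamp k so short rows (fewer than k blocks) still contribute.
        kth_values.append(ranked[min(k, len(ranked)) - 1])
    return sum(kth_values) / len(kth_values)

def calibrate_all(stats, k):
    """stats maps (layer, head) -> block_max_rows; returns a threshold per key."""
    return {key: calibrate_threshold(rows, k) for key, rows in stats.items()}

# Toy statistics for two heads of one layer (made-up numbers).
toy_stats = {
    (0, 0): [[5.0, 1.0, 3.0, 2.0], [4.0, 0.0, 2.0, 6.0]],
    (0, 1): [[1.0, 1.5, 0.5, 0.0]],
}
thresholds = calibrate_all(toy_stats, k=2)
print(thresholds)
```

Under this scheme a threshold keeps roughly $k$ blocks per query on data resembling the calibration set, while queries with sharper or flatter score distributions keep fewer or more, which is one way the task-adaptive sparsity mentioned in the captions could arise.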