
Efficiently Dispatching Flash Attention For Partially Filled Attention Masks

Agniv Sharma, Jonas Geiping

TL;DR

The paper introduces Binary Block Masking, a highly efficient modification that makes Flash Attention mask-aware, and proposes two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks.

Abstract

Transformers are widely used across various applications, many of which yield sparse or partially filled attention matrices. Examples include attention masks designed to reduce the quadratic complexity of attention, sequence packing techniques, and recent innovations like tree masking for fast validation in MEDUSA. Despite the inherent sparsity in these matrices, the state-of-the-art algorithm Flash Attention still processes them with quadratic complexity as though they were dense. In this paper, we introduce Binary Block Masking, a highly efficient modification that enhances Flash Attention by making it mask-aware. We further propose two optimizations: one tailored for masks with contiguous non-zero patterns and another for extremely sparse masks. Our experiments on attention masks derived from real-world scenarios demonstrate up to a 9x runtime improvement. The implementation will be publicly released to foster further research and application.
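The abstract describes Binary Block Masking only at a high level. The NumPy sketch below assumes the core idea works as follows:

- collapse the dense mask into one boolean flag per tile;
- let a Flash-Attention-style tiled loop with an online softmax skip every tile whose flag is False.

The names `binary_block_mask`, `block_masked_attention` and `dense_masked_attention`, the tile sizes `br`/`bc`, and the packed example mask are hypothetical illustration choices. They are not the authors' GPU kernel or API.

```python
import numpy as np

# Hypothetical CPU sketch of block-skipping masked attention (not the paper's kernel).


def binary_block_mask(mask, br, bc):
    """Return a (ceil(n_q/br), ceil(n_k/bc)) boolean matrix that is True
    iff the corresponding br x bc tile of `mask` has any allowed entry."""
    n_q, n_k = mask.shape
    nb_q, nb_k = -(-n_q // br), -(-n_k // bc)  # ceiling division
    blocks = np.zeros((nb_q, nb_k), dtype=bool)
    for i in range(nb_q):
        for j in range(nb_k):
            blocks[i, j] = mask[i * br:(i + 1) * br, j * bc:(j + 1) * bc].any()
    return blocks


def block_masked_attention(q, k, v, mask, br=4, bc=4):
    """Tiled attention with an online softmax that skips empty tiles."""
    n_q, d = q.shape
    n_k = k.shape[0]
    blocks = binary_block_mask(mask, br, bc)
    out = np.zeros((n_q, v.shape[1]))
    for i in range(blocks.shape[0]):
        r0, r1 = i * br, min((i + 1) * br, n_q)
        m = np.full(r1 - r0, -np.inf)          # running row-wise max
        l = np.zeros(r1 - r0)                  # running softmax denominator
        acc = np.zeros((r1 - r0, v.shape[1]))  # unnormalised output
        for j in np.flatnonzero(blocks[i]):    # visit only non-empty tiles
            c0, c1 = j * bc, min((j + 1) * bc, n_k)
            s = q[r0:r1] @ k[c0:c1].T / np.sqrt(d)
            s = np.where(mask[r0:r1, c0:c1], s, -np.inf)
            m_new = np.maximum(m, s.max(axis=1))
            shift = np.where(np.isfinite(m_new), m_new, 0.0)  # avoid inf - inf
            p = np.exp(s - shift[:, None])
            alpha = np.exp(m - shift)          # rescale earlier partial sums
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v[c0:c1]
            m = m_new
        out[r0:r1] = acc / np.where(l > 0, l, 1.0)[:, None]
    return out, blocks


def dense_masked_attention(q, k, v, mask):
    """Reference: ordinary masked softmax attention over the full matrix."""
    s = np.where(mask, q @ k.T / np.sqrt(q.shape[1]), -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p @ v) / p.sum(axis=1, keepdims=True)


# Usage: two packed sequences of lengths 6 and 10, causal within each sequence.
seg = np.array([0] * 6 + [1] * 10)
idx = np.arange(16)
mask = (seg[:, None] == seg[None, :]) & (idx[None, :] <= idx[:, None])
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
out, blocks = block_masked_attention(q, k, v, mask)
print(np.allclose(out, dense_masked_attention(q, k, v, mask)))  # True
print(int(blocks.sum()), "of", blocks.size, "tiles visited")     # 8 of 16 tiles visited
```

In this packed example only 8 of the 16 tiles contain an allowed entry. The inner loop therefore does half the tile work of an unmasked pass, yet the output still matches dense masked attention. The paper's two further optimizations, for contiguous and for extremely sparse masks, are not reproduced in this sketch.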

Paper Structure

This paper contains 19 sections, 14 figures, 10 tables.

Figures (14)

  • Figure 1: Proposed algorithm for masking flash attention
  • Figure 2: The result and performance of the Reverse Cuthill-McKee (RCM) bandwidth reduction algorithm when computing a sparse attention mask.
  • Figure 3: Mask Visualization and Performance Comparison for MEDUSA tree mask
  • Figure 4: Mask Visualization and Performance Comparison for packed ALPACA dataset
  • Figure 5: Mask Visualization and Performance Comparison for Longformer
  • ...and 9 more figures
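Figure 2 applies the RCM (Reverse Cuthill-McKee) bandwidth-reduction algorithm to a sparse attention mask. RCM relabels the nodes of a symmetric sparsity pattern so that non-zeros cluster near the diagonal. Presumably this leaves fewer non-empty tiles for a block-skipping kernel.

The pure-Python sketch below is the textbook algorithm, assuming a symmetric mask:

- run a breadth-first search from a minimum-degree node;
- visit neighbours in increasing-degree order;
- reverse the final order.

The helper names and the toy mask are hypothetical. This extract does not specify the paper's exact variant, such as its choice of starting node. For real workloads, SciPy provides `scipy.sparse.csgraph.reverse_cuthill_mckee`.

```python
from collections import deque

# Hypothetical helpers illustrating Reverse Cuthill-McKee on a symmetric boolean mask.


def mask_to_adj(mask):
    """Treat each off-diagonal True entry of a symmetric mask as a graph edge."""
    n = len(mask)
    return {i: {j for j in range(n) if j != i and mask[i][j]} for i in range(n)}


def reverse_cuthill_mckee(adj):
    """Return a node ordering that tends to reduce the bandwidth of the pattern."""
    def degree(node):
        return (len(adj[node]), node)  # ties broken by node id

    visited, order = set(), []
    for start in sorted(adj, key=degree):  # one BFS per connected component
        if start in visited:
            continue
        visited.add(start)
        queue = deque([start])
        while queue:
            node = queue.popleft()
            order.append(node)
            for nb in sorted(adj[node] - visited, key=degree):
                visited.add(nb)
                queue.append(nb)
    return order[::-1]  # reversing the Cuthill-McKee order gives RCM


def permute(mask, order):
    """Reorder rows and columns of the mask by the same permutation."""
    return [[mask[r][c] for c in order] for r in order]


def bandwidth(mask):
    """Largest |i - j| over all True entries."""
    n = len(mask)
    return max((abs(i - j) for i in range(n) for j in range(n) if mask[i][j]), default=0)


# Toy mask: tokens linked in the scrambled chain 0-4-1-5-2-6-3-7, plus the diagonal.
chain = [0, 4, 1, 5, 2, 6, 3, 7]
mask = [[i == j for j in range(8)] for i in range(8)]
for a, b in zip(chain, chain[1:]):
    mask[a][b] = mask[b][a] = True

order = reverse_cuthill_mckee(mask_to_adj(mask))
print(order)                                                   # [7, 3, 6, 2, 5, 1, 4, 0]
print(bandwidth(mask), "->", bandwidth(permute(mask, order)))  # 4 -> 1
```

On this toy mask the bandwidth drops from 4 to 1, so every non-zero ends up within one position of the diagonal. When tokens are reordered this way, the attention output has to be mapped back with the inverse permutation.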