
Trainable Dynamic Mask Sparse Attention

Jingze Shi, Yifan Wu, Yiran Peng, Bingheng Wu, Liangdong Wang, Guang Liu, Yuyu Luo

TL;DR

This work introduces Dynamic Mask Attention (DMA), a trainable dual-aware sparse attention mechanism that combines content-aware dynamic masking with position-aware sparse weights to address the quadratic bottleneck of self-attention in long-context modeling. A dedicated CUDA kernel fuses FlashAttention-like tiling with hardware-efficient skip logic, reducing time complexity from $O(n^2)$ to $O(n \cdot w)$ and memory complexity to $O(n \cdot w)$ for window size $w \ll n$, while preserving end-to-end differentiability. Extensive experiments on scaling laws, multi-query associative recall, downstream benchmarks, and extrapolated retrieval show that DMA holds a consistent Pareto advantage over state-of-the-art sparse baselines, with speedups of up to 10x. The authors also release open-source kernel code to facilitate adoption and further research on efficient long-context transformers.
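
The move from $O(n^2)$ to $O(n \cdot w)$ can be made concrete by counting query-key score entries. The sketch below is a minimal illustration that assumes causal attention in which each query keeps at most `w` keys; the function names are hypothetical. It ignores tile granularity, the cost of generating the mask, and all hardware effects, so the printed ratio is an idealized arithmetic count, not a prediction of the reported speedups.

```python
def dense_causal_scores(n: int) -> int:
    """Query-key score entries for full causal attention (query i sees keys 0..i)."""
    return n * (n + 1) // 2


def sparse_causal_scores(n: int, w: int) -> int:
    """Score entries when each query keeps at most w of its causal keys (assumed rule)."""
    return sum(min(i + 1, w) for i in range(n))


if __name__ == "__main__":
    for n, w in [(4096, 1024), (32768, 2048)]:
        dense = dense_causal_scores(n)
        sparse = sparse_causal_scores(n, w)
        # Idealized ratio only: ignores tiling, mask-generation cost, and hardware effects.
        print(f"n={n:>6} w={w:>5} dense={dense:,} sparse={sparse:,} ratio={dense / sparse:.2f}x")
```

Once $n \gg w$, the dense count grows like $n^2/2$ while the sparse count grows like $n \cdot w$, so the ratio approaches roughly $n / (2w)$ and keeps growing with context length.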

Abstract

The increasing demand for long-context modeling in large language models (LLMs) is bottlenecked by the quadratic complexity of the standard self-attention mechanism. Sparse attention has been proposed to mitigate this issue. However, position-aware sparse attention methods rely on static sparse structures that cannot adapt to diverse query contexts, while content-aware sparse attention methods depend on heuristic key-value selection, which prevents full differentiability. We introduce Dynamic Mask Attention (DMA), a trainable dynamic mask sparse attention mechanism that combines the advantages of both position-aware and content-aware approaches. DMA rests on three key innovations. First, it uses value vector representations to generate content-aware dynamic masks, enabling the model to adaptively identify and attend to critical information. Second, it computes position-aware sparse weights in a hardware-friendly manner, efficiently skipping unnecessary computational regions. Finally, we demonstrate that the dynamic mask and sparse weights do not obstruct gradients, supporting end-to-end training. Comprehensive experiments show that DMA consistently holds a Pareto advantage over state-of-the-art sparse attention baselines on scaling laws, multi-query associative recall, standard benchmarks, and needle-in-a-haystack tests, while delivering up to a 10x overall speedup. These results highlight its ability to balance model efficiency with long-context modeling capabilities. Our kernel code is open-source at https://github.com/SmallDoges/flash-dmattn to encourage further research and application by the community.
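
The first ingredient, masks generated from value vectors, can be illustrated with a small pure-Python sketch. Everything specific here is an assumption: the per-key score $\exp(a \cdot \mathrm{softplus}(v \cdot \Delta))$ is modeled on Mamba-style zero-order-hold discretization, suggested only by the stride weight $\Delta$, gate weight $A$, and "zero-order hold" named in Figure 2, and the per-query top-`window` selection over causal keys is one possible reading of "causally broadcast to the length of $Q$". The names `key_scores`, `dynamic_mask`, `delta`, `a`, and `window` are hypothetical and are not the flash-dmattn API. Giving each head its own $\Delta$ yields a different mask per head, in the spirit of Figure 4.

```python
import math
from typing import List

Matrix = List[List[float]]


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def key_scores(values: Matrix, delta: List[float], a: float) -> List[float]:
    """Hypothetical per-key content score computed only from value vectors.

    Assumed zero-order-hold style form: exp(a * softplus(v . delta)).
    For a > 0 the ranking increases with v . delta; for a < 0 it is reversed.
    """
    return [
        math.exp(a * softplus(sum(vi * di for vi, di in zip(v, delta))))
        for v in values
    ]


def dynamic_mask(values: Matrix, delta: List[float], a: float, window: int) -> List[List[bool]]:
    """Boolean n x n mask: mask[i][j] is True if query i may attend to key j.

    Each query keeps the `window` highest-scoring keys among positions 0..i
    (ties broken toward earlier positions), so the mask is causal by construction.
    """
    scores = key_scores(values, delta, a)
    n = len(values)
    mask = []
    for i in range(n):
        ranked = sorted(range(i + 1), key=lambda j: (-scores[j], j))
        keep = set(ranked[:window])
        mask.append([j in keep for j in range(n)])
    return mask


if __name__ == "__main__":
    V = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.0, 0.0]]
    # Two heads with different (hypothetical) stride weights produce different masks.
    for head, delta in enumerate([[1.0, 0.0], [0.0, 1.0]]):
        print(f"head {head}:")
        for row in dynamic_mask(V, delta, a=1.0, window=2):
            print("".join("x" if keep else "." for keep in row))
```

The selection in this sketch is a hard top-`window` step; how gradients reach $\Delta$ and $A$ in DMA is described in the paper and is not reproduced here.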

Paper Structure

This paper contains 39 sections, 10 equations, 11 figures, 5 tables, and 2 algorithms.

Figures (11)

  • Figure 1: Workflow and Performance of Dynamic Mask Attention. Left: Overall workflow of DMA. The first step projects the input into $Q$, $K$, and $V$. The second step generates content-aware dynamic masks. The third step computes sparse weights. Black solid arrows indicate the forward computation path, while gray dashed arrows represent the backward computation path. Right: Relative performance comparison between full attention and DMA on benchmark tests. DMA achieves higher recall and substantially higher speed while maintaining competitive accuracy.
  • Figure 2: Dynamic Mask Attention Architecture. Left: Content-Aware Mask Computation, the mask computation stage of dynamic mask attention. In the outer loop, the stride weight $\Delta$ and gate weight $A$ are loaded into high-speed SRAM; in the inner loop, the zero-order hold method iterates over the $V$ blocks loaded into SRAM, sampling from them to generate content-aware $K$ masks. These masks are then causally broadcast to the length of $Q$ in HBM. Finally, the outer loop concatenates all mask blocks to form the final content-aware sparse dynamic mask. Right: Position-Aware Weights Computation, the weight computation stage of dynamic mask attention. In the outer loop, the $K$ and $V$ blocks are loaded into SRAM; in the inner loop, the $Q$ blocks are loaded into SRAM, and the attention weight output is written back to HBM. If the current $K$ block position is marked as masked in the dynamic mask, the attention weight at that position is set directly to 0 and its computation is skipped, yielding the final position-aware sparse attention weights.
  • Figure 3: Sparsity in Language Modeling Tasks. Copy, Select, and Induce are three essential tasks for language modeling. The Copy task requires maintaining a fixed distance between input and output elements, the Select task involves selectively remembering or ignoring elements based on the input, and the Induce task requires retrieving answers through associative recall from context. Colored tokens are those the model must remember at the current inference step, black tokens are the outputs the model must predict from the input, and white tokens are irrelevant and can be filtered out.
  • Figure 4: Dynamic Mask Attention Structure. The mask and weight structures of Dynamic Mask Attention in the multi-head case. Unlike Self-Attention and State-Space models, whose heads share identical, redundant mask and weight structures, DMA adjusts its mask structure dynamically through content awareness, so each head's mask can differ. This lets DMA produce a different attention weight distribution in each head, so that each head focuses on different tokens and the model makes full use of every subspace in multi-head attention.
  • Figure 5: Scaling Laws. Perplexity of different self-attention variants on SmolLMCorpus at different parameter scales. Suboptimal variants such as SWA and MLA are omitted for clarity. Compared with the other variants, our Dynamic Mask Attention holds a Pareto advantage in performance.
  • ...and 6 more figures
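
The skip logic in the right half of Figure 2 can be sketched for a single query in pure Python. This is a minimal illustration under stated assumptions, not the flash-dmattn kernel: it uses a FlashAttention-style online softmax over key tiles of size `block`, skips any tile whose mask entries are all False, and takes a boolean mask row from the caller. Figure 2's kernel loops over $K$/$V$ blocks outside and $Q$ blocks inside; this single-query version collapses the $Q$ loop. The names `attend_dense`, `attend_tiled_skip`, and `block` are hypothetical. The sketch checks that skipping is exact by comparing against dense attention with $-\infty$ masking.

```python
import math
from typing import List, Tuple

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def attend_dense(q: Vec, keys: List[Vec], values: List[Vec], mask_row: List[bool]) -> Vec:
    """Reference: score every key, set masked scores to -inf, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    raw = [dot(q, k) * scale for k in keys]
    scores = [s if keep else -math.inf for s, keep in zip(raw, mask_row)]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # masked entries become exactly 0.0
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(len(values[0]))]


def attend_tiled_skip(q: Vec, keys: List[Vec], values: List[Vec],
                      mask_row: List[bool], block: int = 2) -> Tuple[Vec, int]:
    """Online-softmax attention over key tiles, skipping tiles with no kept key.

    Returns the output vector and the number of query-key scores actually computed.
    """
    if not any(mask_row):
        raise ValueError("query has no unmasked keys")
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                 # running max score
    denom = 0.0                   # running softmax denominator
    acc = [0.0] * len(values[0])  # running unnormalized output
    computed = 0
    for start in range(0, len(keys), block):
        tile = range(start, min(start + block, len(keys)))
        if not any(mask_row[j] for j in tile):
            continue  # fully masked tile: no K/V reads, no scores
        # A partially active tile is scored in full and then masked, as a GPU tile would be.
        s = [dot(q, keys[j]) * scale for j in tile]
        s = [x if mask_row[j] else -math.inf for x, j in zip(s, tile)]
        computed += len(tile)
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)  # rescales earlier partial sums; 0.0 on the first tile
        p = [math.exp(x - m_new) for x in s]
        denom = denom * corr + sum(p)
        acc = [a * corr + sum(pi * values[j][d] for pi, j in zip(p, tile))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc], computed


if __name__ == "__main__":
    q = [1.0, 0.5]
    keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0], [0.5, 0.5], [2.0, -1.0]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
    mask_row = [True, False, False, False, True, True]
    out, computed = attend_tiled_skip(q, keys, values, mask_row, block=2)
    print("tiled:", out, "scores computed:", computed, "of", len(keys))
    print("dense:", attend_dense(q, keys, values, mask_row))
```

Because masked scores contribute $\exp(-\infty) = 0$ to the softmax, dropping them leaves the output unchanged; the saving comes from tiles with no kept keys. Partially active tiles are still scored in full, which is why tile granularity affects how much of the idealized $n \cdot w$ count is realized in practice.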