Optimizing Mixture of Block Attention

Guangxuan Xiao, Junxian Guo, Kasra Mazaheri, Song Han

TL;DR

The paper provides a principled, SNR-based framework to understand MoBA’s block-routing mechanism, linking retrieval accuracy to architectural choices via SNR ∝ sqrt(d/(2B)) and highlighting the benefits of smaller blocks and within-block signal clustering through a short key convolution. Guided by this theory, the authors propose FlashMoBA, a hardware-aware CUDA kernel suite that fuses top-k and gather-densify steps and enables efficient small-block MoBA by coalescing memory access and recomputing attention in the backward pass. Empirical results from training LLMs up to 1B parameters show MoBA with small blocks and key convolution can match or surpass dense attention on long-context tasks, while FlashMoBA delivers up to 14.7x speedups over FlashAttention-2 and scales to very long sequences. Overall, the work provides both a theoretical lens and a practical, scalable implementation path for deploying MoBA in real-world, long-context LLM applications.
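
To make the scaling concrete, the minimal Python sketch below evaluates the factor sqrt(d/(2B)) quoted above. It captures only the proportionality: the constant in front, and any gain from key convolution, are not modeled. The function name `relative_snr` is illustrative and not taken from the paper or its code.

```python
import math


def relative_snr(d: int, block_size: int) -> float:
    """Return sqrt(d / (2B)), the factor the routing SNR is proportional to.

    d is the head dimension and block_size is the MoBA block size B.
    The true SNR also carries a constant that this sketch omits.
    """
    return math.sqrt(d / (2 * block_size))


# Compare a large and a small block size at a fixed head dimension.
large = relative_snr(d=64, block_size=512)
small = relative_snr(d=64, block_size=128)
print(f"B=512: {large:.3f}  B=128: {small:.3f}  gain: {small / large:.1f}x")
```

Under this proportionality, shrinking B from 512 to 128 at fixed d doubles the factor, since sqrt(512/128) = 2.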

Abstract

Mixture of Block Attention (MoBA) (Lu et al., 2025) is a promising building block for efficiently processing long contexts in LLMs by enabling queries to sparsely attend to a small subset of key-value blocks, drastically reducing computational cost. However, the design principles governing MoBA's performance are poorly understood, and it lacks an efficient GPU implementation, hindering its practical adoption. In this paper, we first develop a statistical model to analyze MoBA's underlying mechanics. Our model reveals that performance critically depends on the router's ability to accurately distinguish relevant from irrelevant blocks based on query-key affinities. We derive a signal-to-noise ratio that formally connects architectural parameters to this retrieval accuracy. Guided by our analysis, we identify two key pathways for improvement: using smaller block sizes and applying a short convolution on keys to cluster relevant signals, which enhances routing accuracy. While theoretically better, small block sizes are inefficient on GPUs. To bridge this gap, we introduce FlashMoBA, a hardware-aware CUDA kernel that enables efficient MoBA execution even with the small block sizes our theory recommends. We validate our insights by training LLMs from scratch, showing that our improved MoBA models match the performance of dense attention baselines. FlashMoBA achieves up to 14.7x speedup over FlashAttention-2 for small blocks, making our theoretically-grounded improvements practical. Code is available at: https://github.com/mit-han-lab/flash-moba.
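
The abstract does not specify the form of the key convolution. As a rough illustration only, the sketch below applies a causal depthwise convolution along the sequence axis, so each key becomes a weighted mix of itself and a few preceding keys. The function name `causal_key_conv`, the tap layout, and the absence of an activation or residual connection are all assumptions made for this example, not details confirmed by the text.

```python
def causal_key_conv(K, taps):
    """Causal depthwise convolution over keys (illustrative, not the paper's kernel).

    K:    list of N key vectors, each a list of d floats.
    taps: list of d tap lists; taps[c][t] weights channel c of the key
          t positions in the past (t = 0 is the current token).
    Positions before the start of the sequence are treated as zeros.
    """
    n, d = len(K), len(K[0])
    out = []
    for i in range(n):
        row = []
        for c in range(d):
            acc = 0.0
            for t, w in enumerate(taps[c]):
                if i - t >= 0:  # never look at future tokens
                    acc += w * K[i - t][c]
            row.append(acc)
        out.append(row)
    return out


# Example: width-3 averaging taps (a hypothetical choice) on 2-dimensional keys.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
taps = [[1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3]]
smoothed = causal_key_conv(keys, taps)
```

Because each output key only mixes in earlier positions, the operation preserves causality while spreading a token's signal onto its neighbors, which is the clustering effect the abstract describes.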

Paper Structure

This paper contains 51 sections, 13 equations, 4 figures, 6 tables, 5 algorithms.

Figures (4)

  • Figure 1: FlashMoBA forward pass in two stages. MoBA splits keys and values into blocks; each query scores centroids of key-blocks $\tilde{\mathbf K}$ and attends only to its top-$k$ blocks (plus causally to its own block). 1) Tiled Top-$k$ Selection: a fused kernel streams tiles of $\mathbf Q$ and $\tilde{\mathbf K}$ to emit a sparse routing mask without materializing the full matrix. 2) Dense Computation: for each selected key block, queries are gathered into on-chip SRAM, computed densely with FlashAttention-2 logic, then scattered back. This gather-and-densify strategy coalesces memory, maximizes hardware utilization, and makes small-block MoBA fast on GPUs (see Section \ref{sec:cuda} for more details).
  • Figure 2: Smaller block sizes improve WikiText perplexity and RULER accuracy (340M, $d=64$, 100B tokens). Reducing $B$ from 512 to 128 lowers ppl by 1.2 and increases RULER by 17.2%.
  • Figure 3: Latency & memory vs. length (bsz${=}2$, $B{=}128$, $k{=}8$) for MoBA (original), FlashAttention-2, and FlashMoBA. Top: end-to-end latency; bottom: peak memory. MoBA/FlashMoBA bars are decomposed (top$\to$bottom) into backward, forward, and Top-$k$ overheads. MoBA is dominated by non-attention overheads and hits OOM at 128K. By fusing tiled Top-$k$ with a gather-and-densify kernel, FlashMoBA makes overhead negligible, cuts memory, and is up to 14.7$\times$ faster than FlashAttention-2 at long sequence lengths.
  • Figure 4: Forward-pass timing breakdown ($N{=}64\mathrm{K}$, $B{=}128$, $k{=}8$). MoBA (original) is bottlenecked by routing overheads, while FlashMoBA's fused kernels cut total time to 49 ms, outperforming FlashAttention-2.
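
The routing rule described in the Figure 1 caption can be sketched as a slow, single-head reference in plain Python. This is not FlashMoBA: it performs no tiling, fusion, or gather-and-densify, and it only illustrates which keys each query sees. The function names, the use of the mean key as the block centroid, and the choice to count the top-k over strictly earlier blocks (with the query's own block always added causally) are assumptions made for this example.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def moba_attention(Q, K, V, block_size, top_k):
    """Reference MoBA for one head (illustrative sketch, O(N^2) worst case).

    Q, K, V: lists of N vectors. Each query scores block centroids,
    keeps its top_k earlier blocks, and also attends causally within
    its own block.
    """
    n, d, d_v = len(Q), len(Q[0]), len(V[0])
    n_blocks = (n + block_size - 1) // block_size

    # Centroid of each key block (assumed here to be the mean key).
    centroids = []
    for b in range(n_blocks):
        blk = K[b * block_size:(b + 1) * block_size]
        centroids.append([sum(col) / len(blk) for col in zip(*blk)])

    out = []
    for i, q in enumerate(Q):
        own = i // block_size
        # Route over strictly earlier blocks only, so no future key leaks in.
        scores = [(dot(q, centroids[b]), b) for b in range(own)]
        chosen = sorted(b for _, b in sorted(scores, reverse=True)[:top_k])

        # Visible keys: every key in the chosen blocks, plus the causal
        # prefix of the query's own block.
        idx = [j for b in chosen
               for j in range(b * block_size, (b + 1) * block_size)]
        idx += list(range(own * block_size, i + 1))

        # Numerically stable softmax over the visible keys.
        logits = [dot(q, K[j]) / math.sqrt(d) for j in idx]
        m = max(logits)
        w = [math.exp(x - m) for x in logits]
        z = sum(w)
        out.append([sum(wj * V[j][c] for wj, j in zip(w, idx)) / z
                    for c in range(d_v)])
    return out
```

When top_k is at least the number of blocks, every earlier block is selected and the result coincides with dense causal attention, which gives a simple sanity check for the sketch.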