Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels

Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Sepp Hochreiter

TL;DR

The paper addresses the inefficiency of long-context sequence modeling with linear RNNs by introducing Tiled Flash Linear Attention (TFLA), a kernel algorithm with two levels of sequence parallelism: parallelization over chunks of the sequence, and tiling within each chunk. Applied to the mLSTM, the xLSTM variant with matrix memory, TFLA enables arbitrarily large chunk sizes, which raises arithmetic intensity and reduces memory IO compared to Flash Linear Attention (FLA). The paper further proposes a faster mLSTM variant with a sigmoid input gate (mLSTMsig) and shows, through language modeling experiments and kernel benchmarks, that mLSTMexp and mLSTMsig reach equal language modeling performance while their kernels outperform Flash Attention, Linear Attention and Mamba kernels on long sequences. The work also analyzes how normalization and gate initialization affect transfer behavior and training stability, and derives practical guidance for stable training of long-context models. Overall, TFLA offers a hardware-aware primitive for efficient long-context sequence modeling with linear RNNs.

Abstract

Linear RNNs with gating recently demonstrated competitive performance compared to Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, as Transformers rely on the highly efficient Flash Attention kernels (Dao, 2024). Leveraging the chunkwise-parallel formulation of linear RNNs, Flash Linear Attention (FLA) (Yang & Zhang, 2024) shows that linear RNN kernels can be faster than Flash Attention by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This leads to low arithmetic intensity and causes high memory consumption and IO cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs that enables arbitrarily large chunk sizes and high arithmetic intensity by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM (Beck et al., 2024). Second, we propose an mLSTM variant with a sigmoid input gate and reduced computation for even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.
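
To make the chunkwise-parallel formulation mentioned in the abstract concrete, the following is a minimal NumPy sketch, not the authors' kernel. It assumes plain causal linear attention without the mLSTM gates, normalization, or stabilization, and the function names are hypothetical. It shows that the chunkwise form reproduces the step-by-step recurrence while materializing only one memory state per chunk, so a larger chunk size means fewer states.

```python
import numpy as np

def recurrent_linear_attention(Q, K, V):
    """Reference: ungated linear RNN, one memory-state update per time step."""
    T, d_qk = Q.shape
    d_v = V.shape[1]
    C = np.zeros((d_qk, d_v))          # matrix memory state
    H = np.zeros((T, d_v))
    for t in range(T):
        C = C + np.outer(K[t], V[t])   # C_t = C_{t-1} + k_t v_t^T
        H[t] = Q[t] @ C                # h_t = q_t^T C_t
    return H

def chunkwise_linear_attention(Q, K, V, chunk_size):
    """Chunkwise-parallel form: recurrence across chunks, matmuls within a chunk."""
    T, d_qk = Q.shape
    d_v = V.shape[1]
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    C = np.zeros((d_qk, d_v))          # state from all previous chunks
    H = np.zeros((T, d_v))
    causal = np.tril(np.ones((chunk_size, chunk_size)))
    for start in range(0, T, chunk_size):
        sl = slice(start, start + chunk_size)
        Qc, Kc, Vc = Q[sl], K[sl], V[sl]
        inter = Qc @ C                         # contribution of previous chunks
        intra = ((Qc @ Kc.T) * causal) @ Vc    # causal contribution inside the chunk
        H[sl] = inter + intra
        C = C + Kc.T @ Vc                      # one materialized state per chunk
    return H
```

With sequence length T and chunk size L, this loop produces T / L intermediate states, which is the quantity that a limited chunk size keeps large.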

Paper Structure

This paper contains 112 sections, 51 equations, 25 figures, 18 tables, 1 algorithm.

Figures (25)

  • Figure 1: Tiled Flash Linear Attention (TFLA) consists of a recurrent kernel and a parallel kernel, which process the input sequence in chunks $\bm{Q}^{(k)}, \bm{K}^{(k)}, \bm{V}^{(k)}$ (1st level of sequence parallelism). The recurrent kernel materializes the memory state $\bm{C}_{k-1}$ for each chunk. The parallel kernel computes the output states $\bm{\mathrm{H}}^{(k)}$ for all chunks. TFLA uses tiling for the three matrix multiplications in the parallel kernel (2nd level of sequence parallelism) to fully utilize the hardware and to prevent materialization of many memory states.
  • Figure 2: Illustration of the chunkwise gates $\bm{\mathrm{a}}_k$, $\bm{\mathrm{b}}_k$ and $\mathrm{g}_k$ with chunk size $L=4$. Each arrow denotes an element in the gate vectors. See Figure \ref{fig:mlstm_chunkwise_gates} in Appendix \ref{app:detailed_chunkwise_parallel_formulation} for more details.
  • Figure 3: TFLA Forward Pass Tiling. We loop over $B_{Lkv}$ and $B_{dqk}$ (indicated by arrows) and parallelize over $B_{Lhq}$ and $B_{dhv}$ (indicated by dashed lines) blocks. $\bigoplus$ denotes block-wise accumulation.
  • Figure 4: Transfer behavior of the mLSTM before and after the RMS-norm layer ($\epsilon = 10^{-6}$) for different input and forget gate values. The color shows the gain of the mLSTM defined in (\ref{eq:gain}). After the norm layer, mLSTMexp and mLSTMsig exhibit the same transfer behavior.
  • Figure 5: TFLA Kernel Runtime Benchmark for embedding dimension 4096 and 65,536 tokens on NVIDIA H100 GPUs. In training, our TFLA kernels are faster than FlashAttention 3 for longer sequences and over 2x faster than Mamba 2 kernels for all sequence lengths.
  • ...and 20 more figures
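
The tiling scheme described in the caption of Figure 3 can be illustrated with a short NumPy sketch. It is only a sketch under stated assumptions: it covers the intra-chunk part of an ungated causal linear attention and omits the mLSTM gates, the normalizer, and the inter-chunk state. The function name is hypothetical, and the block-size arguments reuse the caption's names $B_{Lhq}$, $B_{Lkv}$, $B_{dqk}$ and $B_{dhv}$. The two outer loops stand in for the dimensions a kernel would parallelize over, and the two inner loops are the accumulation loops.

```python
import numpy as np

def tiled_intra_chunk(Q, K, V, B_Lhq, B_Lkv, B_dqk, B_dhv):
    """Tiled computation of H = tril(Q K^T) V for one chunk (gates omitted)."""
    L, d_qk = Q.shape
    d_hv = V.shape[1]
    H = np.zeros((L, d_hv))
    for i in range(0, L, B_Lhq):              # query blocks: parallel in a kernel
        for j in range(0, d_hv, B_dhv):       # value-dim blocks: parallel in a kernel
            Qi = Q[i:i + B_Lhq]
            acc = np.zeros((Qi.shape[0], V[:, j:j + B_dhv].shape[1]))
            for s in range(0, L, B_Lkv):      # loop over key/value blocks
                Ks = K[s:s + B_Lkv]
                S = np.zeros((Qi.shape[0], Ks.shape[0]))
                for p in range(0, d_qk, B_dqk):   # loop over the q/k feature dim
                    S += Qi[:, p:p + B_dqk] @ Ks[:, p:p + B_dqk].T
                # causal mask: a query at position r only sees keys at positions <= r
                q_pos = np.arange(i, i + Qi.shape[0])[:, None]
                k_pos = np.arange(s, s + Ks.shape[0])[None, :]
                S = np.where(k_pos <= q_pos, S, 0.0)
                acc += S @ V[s:s + B_Lkv, j:j + B_dhv]   # block-wise accumulation
            H[i:i + B_Lhq, j:j + B_dhv] = acc
    return H
```

Because each output tile is accumulated from small blocks, the chunk length L can grow without any single tile growing, which is the property that lets the chunk size be chosen freely.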