
Trainable Log-linear Sparse Attention for Efficient Diffusion Transformers

Yifan Zhou, Zeqi Xiao, Tianyi Wei, Shuai Yang, Xingang Pan

TL;DR

This work tackles the quadratic cost of self-attention in diffusion transformers by introducing Log-linear Sparse Attention (LLSA), a trainable hierarchical Top-K mechanism that reduces attention complexity to $O(N\log N)$ via multi-level compression and selective KV enrichment. A high-performance GPU kernel runs the computation end to end on sparse indices, with no dense masks, which makes both training and inference on long pixel-space token sequences efficient. On high-resolution pixel-space DiTs, evaluated on FFHQ and ImageNet, LLSA delivers substantial throughput gains with generation quality competitive with, or better than, prior trainable sparse attention methods. The proposed Hierarchical KV Enrichment and efficient backward indexing make diffusion models more scalable, paving the way for longer sequences and higher-resolution generation.
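Backward indexing matters because the key and value gradients for a key block need the list of query blocks that selected it, while the forward pass stores the opposite direction (query block to selected key blocks). One standard way to get that list without a dense N×N mask is to transpose the sparse index lists into a CSR-style key-to-query layout with a count pass and a scatter pass. The plain-Python sketch below illustrates that idea only; `transpose_block_indices` and the CSR layout are our assumptions, not the data structures of the LLSA kernel.

```python
from typing import List, Tuple


def transpose_block_indices(
    q2k: List[List[int]], num_k_blocks: int
) -> Tuple[List[int], List[int]]:
    """Invert per-query-block Top-K lists into a CSR-style key -> query map.

    q2k[i] lists the key blocks selected by query block i (forward-pass indices).
    Returns (offsets, q_idx): key block j is used by q_idx[offsets[j]:offsets[j + 1]].
    """
    # Pass 1: count how many query blocks selected each key block.
    counts = [0] * num_k_blocks
    for selected in q2k:
        for j in selected:
            counts[j] += 1
    # Exclusive prefix sum gives the start offset of each key block's segment.
    offsets = [0] * (num_k_blocks + 1)
    for j in range(num_k_blocks):
        offsets[j + 1] = offsets[j] + counts[j]
    # Pass 2: scatter query-block ids into their key block's segment.
    cursor = offsets[:-1]  # slicing makes a copy, so offsets stays intact
    q_idx = [0] * offsets[-1]
    for i, selected in enumerate(q2k):
        for j in selected:
            q_idx[cursor[j]] = i
            cursor[j] += 1
    return offsets, q_idx


if __name__ == "__main__":
    # Query blocks 0 and 1 chose key block 1; query blocks 2 and 3 chose key block 3.
    print(transpose_block_indices([[1], [1], [3], [3]], num_k_blocks=4))
    # ([0, 0, 2, 2, 4], [0, 1, 2, 3])
```

With this layout, a backward pass can accumulate the gradient of key block j by iterating only over its segment of `q_idx`, so memory stays proportional to the number of selected (query, key) pairs. On a GPU the two passes would typically become atomic counters or a sort; how LLSA actually builds its backward indices is not described here.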

Abstract

Diffusion Transformers (DiTs) set the state of the art in visual generation, yet their quadratic self-attention cost fundamentally limits scaling to long token sequences. Recent Top-K sparse attention approaches reduce the computation of DiTs by compressing tokens into block-wise representations and selecting a small set of relevant key blocks, but they still suffer from (i) a quadratic selection cost on the compressed tokens and (ii) the increasing K required to maintain model quality as sequences grow. We attribute this inefficiency to their single-level design: a single coarse level is insufficient to represent the global structure. In this paper, we introduce Log-linear Sparse Attention (LLSA), a trainable sparse attention mechanism for extremely long token sequences that uses a hierarchical structure to reduce both selection and attention costs from quadratic to log-linear complexity. LLSA performs hierarchical Top-K selection, progressively applying sparse Top-K selection guided by the indices found at the previous level, and introduces a Hierarchical KV Enrichment mechanism that preserves global context while attending to fewer tokens, at different granularities. To support efficient training, we develop a high-performance GPU implementation that uses only sparse indices for both the forward and backward passes, eliminating the need for dense attention masks. We evaluate LLSA on high-resolution pixel-space image generation without patchification or VAE encoding. LLSA accelerates attention inference by 28.27x and DiT training by 6.09x on 256x256 pixel token sequences while maintaining generation quality. The results demonstrate that LLSA offers a promising direction for training long-sequence DiTs efficiently. Code is available at: https://github.com/SingleZombie/LLSA
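To make the hierarchical selection concrete, the toy sketch below builds a pyramid of compressed tokens, runs a dense Top-K only at the coarsest level, and at every finer level scores just the children of the keys chosen by the parent query. It also lists, for one query block, the fine tokens plus the coarse tokens selected by its ancestors, as a stand-in for Hierarchical KV Enrichment. Several choices are assumptions rather than details from the paper: mean pooling as the compression, dot-product scores, the helper names (`hierarchical_top_k`, `enriched_kv`), and including every ancestor-selected coarse token without deduplication or reweighting.

```python
from typing import Dict, List, Tuple

Vec = List[float]
Selection = Dict[int, List[List[int]]]


def mean_pool(tokens: List[Vec], b: int) -> List[Vec]:
    """Average every b consecutive tokens into one coarser token (assumed compression)."""
    dim = len(tokens[0])
    return [
        [sum(tok[d] for tok in tokens[i:i + b]) / b for d in range(dim)]
        for i in range(0, len(tokens), b)
    ]


def dot(u: Vec, v: Vec) -> float:
    return sum(x * y for x, y in zip(u, v))


def top_k(query: Vec, keys: List[Vec], candidates: List[int], k: int) -> List[int]:
    """The k candidate indices with the highest query-key scores, sorted ascending."""
    ranked = sorted(candidates, key=lambda j: dot(query, keys[j]), reverse=True)
    return sorted(ranked[:k])


def build_pyramid(tokens: List[Vec], b: int, levels: int) -> List[List[Vec]]:
    """pyramid[0] is the full sequence; pyramid[l] has len(tokens) / b**l tokens."""
    pyramid = [tokens]
    for _ in range(levels):
        pyramid.append(mean_pool(pyramid[-1], b))
    return pyramid


def hierarchical_top_k(
    q_tokens: List[Vec], k_tokens: List[Vec], b: int, k: int, levels: int
) -> Selection:
    """sel[l][i] = level-l key indices selected for level-l query i, for l = 1..levels."""
    qp = build_pyramid(q_tokens, b, levels)
    kp = build_pyramid(k_tokens, b, levels)
    sel: Selection = {}
    # Coarsest level: dense Top-K over every coarse key (this level is short).
    all_keys = list(range(len(kp[levels])))
    sel[levels] = [top_k(q, kp[levels], all_keys, k) for q in qp[levels]]
    # Finer levels: score only the b children of each key chosen by the parent query.
    for lvl in range(levels - 1, 0, -1):
        sel[lvl] = []
        for i, q in enumerate(qp[lvl]):
            parent_sel = sel[lvl + 1][i // b]
            cands = [j * b + c for j in parent_sel for c in range(b)]
            sel[lvl].append(top_k(q, kp[lvl], cands, k))
    return sel


def enriched_kv(sel: Selection, block: int, b: int) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Fine token positions of the level-1 key blocks chosen for query block `block`,
    plus (level, index) of each coarse key selected by its ancestors (hypothetical rule)."""
    fine = [j * b + c for j in sel[1][block] for c in range(b)]
    coarse: List[Tuple[int, int]] = []
    idx = block
    for lvl in range(2, max(sel) + 1):
        idx //= b
        coarse.extend((lvl, j) for j in sel[lvl][idx])
    return fine, coarse


if __name__ == "__main__":
    # Toy setting from Figure 1: N = 8, B = 2, K = 1, with two compressed levels.
    keys = [[1, 0], [1, 0], [3, 0], [3, 0], [0, 1], [0, 1], [0, 2], [0, 2]]
    queries = [[1, 0]] * 4 + [[0, 1]] * 4
    sel = hierarchical_top_k(queries, keys, b=2, k=1, levels=2)
    print(sel)  # {2: [[0], [1]], 1: [[1], [1], [3], [3]]}
    print(enriched_kv(sel, block=2, b=2))  # ([6, 7], [(2, 1)])
```

Because each query at a finer level scores only K·B candidates, the work at a level is proportional to that level's token count, and the sum over geometrically shrinking levels stays linear in N, consistent with the $O(N)$ selection cost stated in the Figure 1 caption. Each query then attends to K·B fine tokens plus roughly K coarse tokens per level, which is where the $O(K \log N)$ enrichment term comes from.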

Paper Structure

This paper contains 25 sections, 4 equations, 8 figures, 5 tables, and 2 algorithms.

Figures (8)

  • Figure 1: Comparison between a general Top-$K$ sparse attention and our Log-linear Sparse Attention (LLSA). In the example, we use a token sequence of length $N=8$, block size $B=2$, and Top-$K$ parameter $K=1$. To reduce the complexity of the selection stage from $O(N^2)$ to $O(N)$, we extend single-level selection to $O(\log N)$ levels. To achieve this, we compute the Top-$K$ over the full sequence at the coarsest level and recursively compute sparse Top-$K$ on the remaining levels. To preserve global context for attention, we enrich the key and value sets of each query with the $O(K \log N)$ coarse tokens found in the selection stage.
  • Figure 2: Illustration of index reordering. The default raster indices do not cluster similar pixels effectively during 1D pooling, whereas the reordered indices guarantee that similar pixels receive neighboring 1D indices.
  • Figure 3: Acceleration ratio of different attention methods compared to PyTorch Attention (FlashAttention2). We evaluate training and inference with block size $B \in \{16,64\}$ across varying sequence lengths on an H200 GPU.
  • Figure 4: Throughput of the sparse key-value backward pass. Experiments are conducted on an H200 GPU using tokens with $64$ heads and head dimension $64$. We set $K=8$ and $B=16$ for sparse Top-$K$ attention.
  • Figure 5: The FID curves of different training strategies. Compared to training from scratch, starting from a model pretrained on low-resolution data significantly reduces training cost.
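The index reordering in Figure 2 can be illustrated with a space-filling-curve ordering. The caption does not name the ordering LLSA uses; Morton (Z-order) is one standard choice with the stated property, so treat it as an assumption. With a block size $B = 4^m$, every run of $B$ consecutive Z-order indices covers an aligned $2^m \times 2^m$ square, so 1D mean pooling over $B$ tokens acts like 2D patch pooling at every level of the hierarchy. The helpers `zorder_permutation` and `invert` are hypothetical names.

```python
from typing import List


def morton_index(row: int, col: int, bits: int) -> int:
    """Interleave the bits of (row, col) into a Z-order index (col on even bits)."""
    idx = 0
    for bit in range(bits):
        idx |= ((col >> bit) & 1) << (2 * bit)
        idx |= ((row >> bit) & 1) << (2 * bit + 1)
    return idx


def zorder_permutation(side: int) -> List[int]:
    """perm[p] = raster index (row * side + col) of the pixel placed at 1D position p."""
    bits = side.bit_length() - 1
    if side <= 0 or 1 << bits != side:
        raise ValueError("side must be a positive power of two")
    perm = [0] * (side * side)
    for row in range(side):
        for col in range(side):
            perm[morton_index(row, col, bits)] = row * side + col
    return perm


def invert(perm: List[int]) -> List[int]:
    """Inverse permutation, used to scatter tokens back to raster order."""
    inv = [0] * len(perm)
    for pos, src in enumerate(perm):
        inv[src] = pos
    return inv


if __name__ == "__main__":
    perm = zorder_permutation(4)
    print(perm)  # [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
    # Reorder raster tokens before attention, then restore raster order afterwards.
    raster = list(range(16))
    reordered = [raster[i] for i in perm]
    restored = [reordered[p] for p in invert(perm)]
    print(restored == raster)  # True
```

In raster order, pooling the first four tokens of a 4×4 image would merge a 1×4 strip of the top row; after reordering, the first four tokens (raster indices 0, 1, 4, 5) form a 2×2 square, so nearby pixels are merged together.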