
DSV: Exploiting Dynamic Sparsity to Accelerate Large-Scale Video DiT Training

Xin Tan, Yuetao Chen, Yimin Jiang, Xing Chen, Kun Yan, Nan Duan, Yibo Zhu, Daxin Jiang, Hong Xu

TL;DR

DSV tackles the heavy computational cost of self-attention in large-scale video Diffusion Transformers (DiTs) by exploiting the dynamic sparsity of attention. It combines three techniques: a two-stage training regime with learnable low-rank predictors that identify critical key-value (KV) pairs; fused kernels and query grouping that compute the resulting sparse attention efficiently; and a hybrid sparsity-aware context parallelism that balances computation and communication. The approach delivers up to 3.02x higher training throughput and substantial latency reductions across datasets and model scales while preserving video quality. Together, algorithmic sparsity, efficient kernels, and adaptive parallelism make high-definition video generation training scalable on large GPU clusters.
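The predictor idea in the TL;DR can be illustrated with a minimal NumPy sketch. It assumes hypothetical rank-r projection matrices `w_q` and `w_k` (random here; in DSV they are learned, and neither the training objective nor query grouping is reproduced) that score every key cheaply. The sketch keeps the `top_k` highest-scoring keys per query as the "critical" KV pairs and then computes exact attention over only those pairs. All function names are illustrative rather than DSV's API, and the dense NumPy gather stands in for the fused kernels the paper describes.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q, k, v):
    # Dense reference: every query attends to every key, O(L^2 * d).
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

def predict_critical_kv(q, k, w_q, w_k, top_k):
    # Hypothetical low-rank predictor: project queries and keys to rank r << d
    # and rank keys by the cheap approximate score (q W_q)(k W_k)^T, O(L^2 * r).
    approx = (q @ w_q) @ (k @ w_k).T
    # Indices of the top_k highest approximate scores per query (unordered).
    return np.argpartition(-approx, top_k - 1, axis=-1)[:, :top_k]

def sparse_attention(q, k, v, idx):
    # Exact attention restricted to the selected key-value pairs of each query.
    k_sel = k[idx]  # (L_q, top_k, d)
    v_sel = v[idx]  # (L_q, top_k, d_v)
    scores = np.einsum("qd,qkd->qk", q, k_sel) / np.sqrt(q.shape[-1])
    return np.einsum("qk,qkd->qd", softmax(scores), v_sel)

# Usage on synthetic data (hypothetical sizes).
rng = np.random.default_rng(0)
L, d, r = 64, 32, 4
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
w_q, w_k = rng.standard_normal((d, r)), rng.standard_normal((d, r))
idx = predict_critical_kv(q, k, w_q, w_k, top_k=8)
out = sparse_attention(q, k, v, idx)
```

With identity projections and `top_k` equal to the sequence length, `sparse_attention` reproduces `full_attention`, which is a convenient sanity check; the savings come from `top_k` being much smaller than the sequence length and `r` much smaller than `d`.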

Abstract

Diffusion Transformers (DiTs) have shown remarkable performance in generating high-quality videos. However, the quadratic complexity of 3D full attention remains a bottleneck in scaling DiT training, especially for long, high-definition videos, where attention can consume up to 95% of processing time and demands specialized context parallelism. This paper introduces DSV, which accelerates video DiT training by leveraging the dynamic attention sparsity we empirically observe. DSV uses a two-stage algorithm to capture the dynamic sparsity patterns via low-rank approximations of the original queries and keys. It employs custom kernels to efficiently identify critical key-value pairs and compute the sparse attention. To accommodate this new sparsity dimension, DSV adopts a hybrid sparsity-aware context parallelism that rebalances the workload skewed across attention heads and blocks by heterogeneous sparsity. DSV achieves up to 3.02x higher training throughput, scaling to 128 GPUs and sequence lengths of 520k tokens, without quality loss.
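The abstract does not say how the rebalancing is computed. As a simple stand-in (not DSV's algorithm), the sketch below applies the classic greedy longest-processing-time (LPT) heuristic: each attention head gets a hypothetical cost proportional to the sparse work it retains, and heads are assigned, most expensive first, to the currently least-loaded GPU. The costs are made-up illustrative numbers, and the sketch ignores both communication and the block (sequence) dimension that DSV's hybrid scheme also balances.

```python
import heapq

def balance_heads(head_costs, num_gpus):
    """Greedy LPT assignment of attention heads to GPUs.

    head_costs: hypothetical per-head workload estimates, e.g. the number of
    critical KV pairs each head keeps after sparsification.
    Returns one list of head indices per GPU.
    """
    # Min-heap of (current load, gpu id): the least-loaded GPU pops first.
    heap = [(0, g) for g in range(num_gpus)]
    heapq.heapify(heap)
    assignment = [[] for _ in range(num_gpus)]
    # Place the most expensive heads first so large items get spread out.
    for h in sorted(range(len(head_costs)), key=lambda i: -head_costs[i]):
        load, g = heapq.heappop(heap)
        assignment[g].append(h)
        heapq.heappush(heap, (load + head_costs[h], g))
    return assignment

def loads(assignment, head_costs):
    # Total cost placed on each GPU.
    return [sum(head_costs[h] for h in heads) for heads in assignment]

head_costs = [8, 7, 1, 1, 1, 1, 1, 1]  # illustrative costs, not measured data
naive = [[0, 1, 2, 3], [4, 5, 6, 7]]   # contiguous split of heads over 2 GPUs
greedy = balance_heads(head_costs, num_gpus=2)
print(loads(naive, head_costs), loads(greedy, head_costs))  # [17, 4] [11, 10]
```

LPT is a standard, cheap heuristic; a real sparsity-aware scheme would also have to weigh the communication cost of moving heads and KV blocks between GPUs.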

Paper Structure

This paper contains 38 sections, 2 equations, 22 figures, 4 tables, and 1 algorithm.

Figures (22)

  • Figure 1: Overview of video DiT training. (a) The main input is the video, which is compressed by a VAE (omitted here). The diffusion timestep is used as conditioning, and the text prompt provides the key-value input to the cross-attention module. (b) Interleaved spatial-temporal attention blocks. (c) 3D full attention blocks.
  • Figure 2: Time breakdown of self-attention versus other operations in DiTs of two sizes (left: 1.3B, right: 3B) across token lengths, for forward (FW, left bar) and backward (BW, right bar) computation.
  • Figure 3: Left: histogram of the attention score distribution for each query. Right: cumulative distribution function of the sorted attention scores for each query.
  • Figure 4: Output difference between attention computed over the full KV set and over only the critical KV pairs.
  • Figure 5: Distribution of critical KV positions for a query (green) at position (15, 15, 15) in the 3D latent space. The visualized keys are those whose attention scores with the query exceed the 90th percentile.
  • ...and 17 more figures
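The quantities behind Figures 3 to 5 can be computed for a single query with a short NumPy sketch: the softmax attention scores, the CDF of the sorted scores (how many keys carry a given fraction of the attention mass), the keys above the 90th percentile, and the relative output difference when attention is renormalised over only those keys. The function names and the renormalisation choice are assumptions for illustration. On the random Gaussian inputs used here the scores are not expected to show the sparsity the figures report; reproducing it would require real queries and keys from a trained DiT.

```python
import numpy as np

def attention_probs(q, k):
    # Softmax-normalised attention scores of one query (d,) over all keys (L, d).
    s = k @ q / np.sqrt(q.shape[-1])
    s = np.exp(s - s.max())
    return s / s.sum()

def critical_mask(p, percentile=90.0):
    # Keys whose score exceeds the given percentile (Figure 5 uses the 90th).
    return p > np.percentile(p, percentile)

def keys_for_mass(p, mass=0.9):
    # Smallest number of top keys whose cumulative score reaches `mass`,
    # read off the CDF of the sorted scores (as in Figure 3, right).
    cdf = np.cumsum(np.sort(p)[::-1])
    return min(int(np.searchsorted(cdf, mass)) + 1, len(p))

def output_difference(q, k, v, mask):
    # Relative L2 gap between full-KV attention and attention renormalised
    # over the masked (critical) keys only, in the spirit of Figure 4.
    p = attention_probs(q, k)
    full = p @ v
    p_crit = np.where(mask, p, 0.0)
    crit = (p_crit / p_crit.sum()) @ v
    return float(np.linalg.norm(full - crit) / np.linalg.norm(full))

# Usage on synthetic data (hypothetical sizes).
rng = np.random.default_rng(0)
L, d = 1024, 64
k, v = rng.standard_normal((L, d)), rng.standard_normal((L, d))
q = rng.standard_normal(d)
p = attention_probs(q, k)
mask = critical_mask(p)
print(keys_for_mass(p), int(mask.sum()), output_difference(q, k, v, mask))
```

A small `keys_for_mass` together with a small `output_difference` on real activations is what would indicate that attention can be restricted to a few critical keys without changing the output much.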