
BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences

Ao Sun, Weilin Zhao, Xu Han, Cheng Yang, Zhiyuan Liu, Chuan Shi, Maosong Sun

TL;DR

This work tackles the quadratic time and memory costs of attention on extremely long sequences by distributing the attention computation across the devices of a cluster and tiling it within each device. BurstAttention combines inter-device partitioning with intra-device tiling, supported by Global Attention Optimization (GAO) and Local Attention Optimization (LAO), together with online-softmax-based global aggregation and double-buffered communication that overlaps computation with data transfer. GAO avoids storing the full attention-score matrix $S$ (computed from $QK^\top$) and the softmax-probability matrix $P$ by accumulating local results on the fly, while LAO uses SRAM tiling to keep local computation in high-bandwidth memory; together they reduce memory and I/O demands and lower communication overhead. The approach is compatible with sparse attention and other distributed strategies. Experiments show about a 40% reduction in communication overhead, training speedups of up to nearly 2× on very long sequences, and strong scalability on multi-GPU clusters. Overall, BurstAttention offers a practical, scalable solution for efficient long-sequence attention in both training and inference.
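The online-softmax aggregation behind GAO and LAO can be illustrated with a small single-process NumPy simulation. This is a minimal sketch under stated assumptions, not the authors' implementation: the function names (`online_update`, `burst_attention_sim`, `reference_attention`) are hypothetical, "devices" are loop indices rather than real GPUs, the ring direction is assumed, and only the forward pass is shown (no backward recomputation). Each simulated device keeps a running row max `m`, a running sum `l`, and an unnormalized output `o`, and folds in one K/V tile at a time, so no full $S$ or $P$ is ever materialized.

```python
import numpy as np


def reference_attention(q, k, v):
    """Dense softmax(QK^T / sqrt(d)) V, used only to check the sketch."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def online_update(o, m, l, q, k_tile, v_tile, scale):
    """Fold one K/V tile into the running (output, row max, row sum) state."""
    s = (q @ k_tile.T) * scale                 # tile-sized scores only
    m_new = np.maximum(m, s.max(axis=-1))      # updated running row max
    alpha = np.exp(m - m_new)                  # rescales the old state (0 at start)
    p = np.exp(s - m_new[:, None])             # unnormalized local probabilities
    l_new = alpha * l + p.sum(axis=-1)
    o_new = alpha[:, None] * o + p @ v_tile
    return o_new, m_new, l_new


def burst_attention_sim(q, k, v, num_devices, tile):
    """Hypothetical simulation: ring K/V passing (inter-device) + tiling (intra-device)."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    q_shards = np.array_split(q, num_devices)
    kv_shards = list(zip(np.array_split(k, num_devices), np.array_split(v, num_devices)))
    outputs = []
    for rank, q_loc in enumerate(q_shards):
        o = np.zeros((q_loc.shape[0], v.shape[-1]))
        m = np.full(q_loc.shape[0], -np.inf)
        l = np.zeros(q_loc.shape[0])
        for step in range(num_devices):
            # Assumed ring: at step `step`, device `rank` holds the shard of rank - step.
            k_blk, v_blk = kv_shards[(rank - step) % num_devices]
            for start in range(0, k_blk.shape[0], tile):  # LAO-style tiles
                o, m, l = online_update(o, m, l, q_loc,
                                        k_blk[start:start + tile],
                                        v_blk[start:start + tile], scale)
        outputs.append(o / l[:, None])  # normalize once, after all blocks
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((60, 16)) for _ in range(3))
out = burst_attention_sim(q, k, v, num_devices=4, tile=7)
print(np.allclose(out, reference_attention(q, k, v)))  # expected True
```

Because each update rescales the earlier partial sums by `alpha`, the order in which ring blocks and tiles arrive does not change the final result; only tile-sized score blocks exist at any moment.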

Abstract

Effective attention modules have played a crucial role in the success of Transformer-based large language models (LLMs), but the quadratic time and memory complexities of these attention modules also pose a challenge when processing long sequences. One potential solution for the long sequence problem is to utilize distributed clusters to parallelize the computation of attention modules across multiple devices (e.g., GPUs). However, adopting a distributed approach inevitably introduces extra memory overheads to store local attention results and incurs additional communication costs to aggregate local results into global ones. In this paper, we propose a distributed attention framework named "BurstAttention" to optimize memory access and communication operations at both the global cluster and local device levels. In our experiments, we compare BurstAttention with other competitive distributed attention solutions for long-sequence processing. The experimental results under different length settings demonstrate that BurstAttention offers significant advantages over these baselines for processing long sequences, reducing communication overhead by 40% and achieving a 1.37× speedup when training with 128K sequence length on 32×A100 GPUs.

Paper Structure (22 sections, 7 theorems, 8 equations, 4 figures, 4 tables, 1 algorithm)

Key Result

Theorem 2.1

In a Transformer block employing Tensor Parallelism (TP) within the Megatron-V3 framework, the total runtime $T$ is determined by the sum of communication times for all-gather and reduce-scatter operations, and the computation times for the attention (attn) and feedforward (ffn) modules, distributed
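Theorem 2.1 expresses the block runtime as a sum of four terms, $T = T_{\text{all-gather}} + T_{\text{reduce-scatter}} + T_{\text{attn}} + T_{\text{ffn}}$. The sketch below turns that sum into a toy cost model. The collective cost uses the common bandwidth-only estimate for a ring all-gather or reduce-scatter, $\frac{N-1}{N}\cdot\frac{M}{B}$ for $M$ total bytes over $N$ devices with per-link bandwidth $B$. Splitting compute evenly across the $N$ tensor-parallel devices, the function names, and all parameters are assumptions for illustration, not the exact statement of the theorem.

```python
def ring_collective_time(total_bytes: float, num_devices: int, bandwidth: float) -> float:
    """Bandwidth-only ring all-gather / reduce-scatter estimate (latency ignored)."""
    return (num_devices - 1) / num_devices * total_bytes / bandwidth


def tp_block_runtime(
    act_bytes: float,        # bytes per collective (hypothetical parameter)
    num_devices: int,        # tensor-parallel group size N
    bandwidth: float,        # per-link bandwidth, bytes/s
    attn_flops: float,       # total attention FLOPs in the block
    ffn_flops: float,        # total FFN FLOPs in the block
    device_flops: float,     # sustained FLOP/s of one device
    n_all_gather: int,       # all-gathers per block
    n_reduce_scatter: int,   # reduce-scatters per block
) -> float:
    t_comm = ring_collective_time(act_bytes, num_devices, bandwidth)
    t_ag = n_all_gather * t_comm
    t_rs = n_reduce_scatter * t_comm
    # Assumption: attention and FFN work are split evenly across the N devices.
    t_attn = attn_flops / (num_devices * device_flops)
    t_ffn = ffn_flops / (num_devices * device_flops)
    # Plain sum: no overlap between communication and computation is modeled.
    return t_ag + t_rs + t_attn + t_ffn
```

Since the terms are simply added, this model charges every collective in full; any overlap with computation would have to be modeled separately.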

Figures (4)

  • Figure 1: BurstAttention performs a two-step partitioning: it divides the sequence across multiple devices (inter-device) and then splits each subsequence within a single device (intra-device). First, BurstAttention partitions the query, key, and value across devices and passes each sliced subsequence through all devices in a ring-like communication. Each device thus processes only a local attention at a time, which avoids the memory burden of processing an extremely long sequence at once. By transmitting $\mathbf{K},\mathbf{V}$ and aggregating local attention results with online softmax, BurstAttention avoids storing the intermediate result $\mathbf{QK}^T$, which has quadratic memory complexity, and instead recomputes it during the backward pass; we call this global attention optimization (GAO). BurstAttention further partitions the subsequences into smaller tiles to perform block-wise computation within local attention. This exploits the high bandwidth of SRAM while minimizing access to the lower-bandwidth HBM; we call this local attention optimization (LAO). In addition, double buffering lets BurstAttention overlap communication with computation.
  • Figure 2: The training time and memory of LLaMA-7b on 8$\times$A100.
  • Figure 3: The training time and memory of LLaMA-7b on 32$\times$A100.
  • Figure 4: Scaling abilities on different GPU numbers and batch sizes.
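The double-buffering idea in Figure 1 can be sketched generically: while the current K/V block is being computed on, the next block is already being received into the second buffer. In this sketch a single background thread stands in for the communication stream, and `fetch`/`compute` are hypothetical callables (in a real ring, `fetch(0)` would simply return the local shard). The idealized `ring_makespan` model is also an assumption: it ignores latency and assumes communication and computation overlap perfectly.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def double_buffered(fetch: Callable[[int], T], compute: Callable[[T], R], steps: int) -> List[R]:
    """Compute on block i while block i + 1 is fetched in the background."""
    results: List[R] = []
    with ThreadPoolExecutor(max_workers=1) as comm:    # stand-in for a comm stream
        pending = comm.submit(fetch, 0)                # fill the first buffer
        for i in range(steps):
            current = pending.result()                 # wait until buffer i is ready
            if i + 1 < steps:
                pending = comm.submit(fetch, i + 1)    # start filling the other buffer
            results.append(compute(current))           # overlaps with that transfer
    return results


def ring_makespan(comp: float, comm: float, steps: int, overlap: bool) -> float:
    """Idealized time for `steps` ring steps when the first block is already local."""
    if overlap:
        # Each of the first steps - 1 steps costs whichever of compute/transfer is slower.
        return (steps - 1) * max(comp, comm) + comp
    return steps * comp + (steps - 1) * comm
```

For example, with a compute time of 2 and a transfer time of 1 per step over 4 steps, overlap hides every transfer (8 time units instead of 11); once transfer time exceeds compute time, the per-step cost is bounded by communication instead.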

Theorems & Definitions (11)

  • Theorem 2.1
  • Definition 2.2: Input Tensor and Cluster Configuration
  • Lemma 2.3: Communication Time
  • Proposition 2.4: Runtime Calculation
  • Remark 2.5
  • Theorem 3.1
  • Definition 3.2: Input Tensor and Cluster Configuration
  • Lemma 3.3: Activation Communication Time in BurstAttention
  • Lemma 3.4: Weight Communication Time in BurstAttention
  • Proposition 3.5: Runtime Calculation in BurstAttention
  • ...and 1 more
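Lemma 3.3 concerns activation communication time in BurstAttention; its statement is not reproduced here. As a rough, hypothetical illustration of the ring pattern described in Figure 1, the sketch counts the bytes one device sends in the forward pass when it forwards one K shard and one V shard at each of the $N-1$ ring steps. The helper names, the evenly divisible sequence length, concurrent links, and the omission of backward-pass traffic are all illustrative assumptions, not the paper's formula.

```python
def ring_kv_bytes_per_device(seq_len: int, num_devices: int, hidden: int, bytes_per_elem: int) -> int:
    """Bytes one device sends while K/V shards circulate once around the ring."""
    if seq_len % num_devices:
        raise ValueError("sketch assumes seq_len is divisible by num_devices")
    shard_elems = (seq_len // num_devices) * hidden   # elements in one K (or V) shard
    steps = num_devices - 1                           # the local shard needs no transfer
    return steps * 2 * shard_elems * bytes_per_elem   # 2 = one K shard + one V shard


def ring_kv_time(seq_len: int, num_devices: int, hidden: int,
                 bytes_per_elem: int, bandwidth: float) -> float:
    """Idealized transfer time, assuming all ring links run concurrently."""
    return ring_kv_bytes_per_device(seq_len, num_devices, hidden, bytes_per_elem) / bandwidth
```

For example, 128K tokens on 32 devices with a 4096-wide hidden state (as in LLaMA-7B) in 2-byte precision gives 2,080,374,784 bytes (about 2.1 GB) sent per device per attention layer under these assumptions.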