BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences
Ao Sun, Weilin Zhao, Xu Han, Cheng Yang, Zhiyuan Liu, Chuan Shi, Maosong Sun
TL;DR
This work tackles the quadratic time and memory costs of attention on extremely long sequences by partitioning attention computation across the devices of a cluster and tiling it within each device. BurstAttention combines inter-device partitioning with intra-device tiling through two components, Global Attention Optimization (GAO) and Local Attention Optimization (LAO). It adds online-softmax-based global aggregation and double-buffered communication that overlaps computation with data transfer. GAO avoids storing the full attention score matrix $S$ and probability matrix $P$ by accumulating local results on the fly. LAO tiles the local computation into on-chip SRAM to exploit its high bandwidth. Together they reduce memory and I/O demands and lower communication overhead. The approach is compatible with sparse attention and other distributed strategies. Experiments show about a 40% reduction in communication overhead, training speedups of up to nearly 2× on very long sequences, and strong scalability on multi-GPU clusters. Overall, BurstAttention offers a practical, scalable solution for efficient long-sequence attention in both training and inference.
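The global aggregation step can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. NumPy is an assumption, and all function names are hypothetical. The sketch simulates `n_dev` devices sequentially in one process. Each device keeps one query shard and visits every K/V shard in an assumed ring order. Within a shard, it loops over small K/V tiles as a stand-in for LAO's SRAM tiling. Partial results are combined with the online-softmax rescaling rule, so only tile-sized blocks of $S$ and $P$ ever exist. Real communication, double buffering, causal masking, and the backward pass are omitted.

```python
import numpy as np

def attend_block(q, k, v):
    """Partial attention of q against one K/V block.

    Returns (o, m, l): o = exp(S - m) @ v with S = q k^T / sqrt(d),
    m = row-wise max of S, l = row-wise sum of exp(S - m).
    """
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1)
    p = np.exp(s - m[:, None])
    return p @ v, m, p.sum(axis=-1)

def merge(acc, new):
    """Online-softmax merge: rescale both partials to a shared row max."""
    o1, m1, l1 = acc
    o2, m2, l2 = new
    m = np.maximum(m1, m2)
    a1, a2 = np.exp(m1 - m), np.exp(m2 - m)
    return o1 * a1[:, None] + o2 * a2[:, None], m, l1 * a1 + l2 * a2

def local_attention(q, k, v, tile):
    """Intra-device tiling: visit K/V in tiles (stand-in for SRAM tiles)."""
    acc = None
    for j in range(0, k.shape[0], tile):
        part = attend_block(q, k[j:j + tile], v[j:j + tile])
        acc = part if acc is None else merge(acc, part)
    return acc

def distributed_attention(q, k, v, n_dev, tile):
    """Simulate n_dev devices, each owning one sequence shard of Q, K, V."""
    qs, ks, vs = (np.array_split(x, n_dev) for x in (q, k, v))
    outs = []
    for i in range(n_dev):
        acc = None
        for step in range(n_dev):
            # Assumed ring schedule: at step t, device i holds shard (i - t) mod n_dev.
            j = (i - step) % n_dev
            part = local_attention(qs[i], ks[j], vs[j], tile)
            acc = part if acc is None else merge(acc, part)
        o, m, l = acc
        outs.append(o / l[:, None])  # normalize once, after all shards are merged
    return np.concatenate(outs)

def reference_attention(q, k, v):
    """Plain softmax(q k^T / sqrt(d)) v, materializing the full S and P."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
out = distributed_attention(q, k, v, n_dev=4, tile=8)
print(np.allclose(out, reference_attention(q, k, v)))  # True
```

The merge is associative, so in this sketch the same rule serves both levels: tiles within a device and shards across devices. The visiting order does not change the result.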
Abstract
Effective attention modules have played a crucial role in the success of Transformer-based large language models (LLMs), but their quadratic time and memory complexities pose a challenge when processing long sequences. One potential solution to the long-sequence problem is to use distributed clusters to parallelize the computation of attention modules across multiple devices (e.g., GPUs). However, a distributed approach inevitably introduces extra memory overhead to store local attention results and incurs additional communication costs to aggregate local results into global ones. In this paper, we propose a distributed attention framework named "BurstAttention" to optimize memory access and communication operations at both the global cluster and local device levels. In our experiments, we compare BurstAttention with other competitive distributed attention solutions for long-sequence processing. The experimental results under different length settings demonstrate that BurstAttention offers significant advantages for processing long sequences compared with these competitive baselines. It reduces communication overhead by 40% and achieves a 1.37× speedup when training with a 128K sequence length on 32 A100 GPUs.
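A back-of-envelope calculation makes the quadratic memory cost and the per-device overhead concrete. The sketch below assumes fp16 scores (2 bytes each), a single attention head, and a single sequence. It also assumes an even split of a 128K-token sequence over 32 devices, matching the training setting quoted above. The helper name `score_matrix_bytes` is hypothetical. The figures are plain arithmetic, not measurements from the paper.

```python
def score_matrix_bytes(n_q, n_k, bytes_per_elem=2):
    """Bytes needed to materialize an n_q x n_k score matrix S.

    Assumes fp16 (2 bytes per element), one head, one sequence.
    """
    return n_q * n_k * bytes_per_elem

N, G = 128 * 1024, 32      # 128K tokens split evenly over 32 devices
GiB, MiB = 1024**3, 1024**2

full = score_matrix_bytes(N, N)              # full S on a single device
shard_rows = score_matrix_bytes(N // G, N)   # one device's query shard against all keys
block = score_matrix_bytes(N // G, N // G)   # one shard-by-shard block at a time

print(full / GiB)        # 32.0
print(shard_rows / GiB)  # 1.0
print(block / MiB)       # 32.0
```

Under these assumptions, keeping every local result until aggregation would cost each device the middle figure per head. Accumulating block by block bounds the transient score storage at the last figure, and tiling inside the device shrinks it further to tile size.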
