Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU clusters

Vasudev Shyam, Jonathan Pilault, Emily Shepperd, Quentin Anthony, Beren Millidge

TL;DR

Tree Attention reframes self-attention as the gradient of an energy function and uses a tree reduction to parallelize the sequence-axis reduction across multiple GPUs. This topology-aware approach needs only a logarithmic number of communication steps in the number of devices and still produces exact results, delivering decoding speedups of up to 8x with lower peak memory and less communication than Ring Attention. Empirical results across diverse hardware show strong gains for long-context decoding, including on Llama models, and confirm portability to DGX and AMD clusters. The method provides a practical path to efficient long-context decoding on multi-node GPU clusters without sacrificing exactness.
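To make the "gradient of an energy function" idea concrete, here is a minimal single-query sketch in plain Python. The specific energy used, F(zeta) = log sum_a exp(q . k_a + zeta . v_a), is an assumption made for illustration (this summary does not spell out the form), and the usual 1/sqrt(d) scaling is omitted. The function names are hypothetical. The sketch checks numerically that the gradient of F with respect to zeta at zeta = 0 matches ordinary softmax attention.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def energy(q, keys, values, zeta):
    # Assumed energy: F(zeta) = log sum_a exp(q . k_a + zeta . v_a),
    # evaluated with the max-subtraction trick for numerical stability.
    terms = [dot(q, k) + dot(zeta, v) for k, v in zip(keys, values)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

def softmax_attention(q, keys, values):
    # Reference: softmax(q . K^T) V for a single query vector.
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]

def energy_gradient_at_zero(q, keys, values, eps=1e-6):
    # Central finite differences of F with respect to zeta, around zeta = 0.
    dim = len(values[0])
    grad = []
    for j in range(dim):
        plus = [0.0] * dim
        minus = [0.0] * dim
        plus[j], minus[j] = eps, -eps
        diff = energy(q, keys, values, plus) - energy(q, keys, values, minus)
        grad.append(diff / (2 * eps))
    return grad

q = [0.3, -0.2]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
attn = softmax_attention(q, keys, values)
grad = energy_gradient_at_zero(q, keys, values)
```

Because the energy is a logsumexp over the sequence axis, computing it (and hence its gradient) reduces to an associative reduction, which is what allows the tree-structured parallelization.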

Abstract

Our formulation reveals that the reduction across the sequence axis can be efficiently computed in parallel through a tree reduction. Our algorithm for parallelizing exact attention computation across multiple GPUs, called Tree Attention, enables cross-device decoding to be performed asymptotically faster (up to 8x faster in our experiments) than state-of-the-art approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2x less peak memory. We demonstrate that Tree Attention speeds up decoding by up to 4x on Llama 3.1-8B and can be applied to a variety of hardware and networking setups such as H100 DGX nodes, AMD MI300x nodes, and PCIe-connected NVIDIA RTX 4090s. Our code is publicly available here: https://github.com/Zyphra/tree_attention
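The following is a single-process sketch of why sharding the sequence axis can still give exact attention. It is not the authors' implementation (see the repository linked above for that): Python lists stand in for per-GPU shards, and the names `partial_attention`, `combine`, and `finalize` are hypothetical. Each shard produces a running max, a softmax denominator, and an unnormalized output; two such partial results merge with an associative rule, so the merge order does not change the answer.

```python
import math
from functools import reduce

def partial_attention(q, keys, values):
    # Local pass over one shard of the key/value sequence.
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    den = sum(weights)
    num = [sum(w * v[j] for w, v in zip(weights, values))
           for j in range(len(values[0]))]
    return m, den, num

def combine(a, b):
    # Associative merge of two partial results: rescale both to a shared max.
    ma, da, na = a
    mb, db, nb = b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    return m, ca * da + cb * db, [ca * x + cb * y for x, y in zip(na, nb)]

def finalize(partial):
    # Divide the unnormalized output by the softmax denominator.
    _, den, num = partial
    return [x / den for x in num]

q = [0.5, -1.0]
keys = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, -1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

# Two simulated shards, each holding half of the sequence.
shards = [(keys[:2], values[:2]), (keys[2:], values[2:])]
partials = [partial_attention(q, k, v) for k, v in shards]
sharded = finalize(reduce(combine, partials))

# Reference: the same computation over the unsharded sequence.
full = finalize(partial_attention(q, keys, values))
```

In a multi-device setting, only the small `(max, denominator, numerator)` tuple per shard would need to be communicated during the reduction, rather than the keys and values themselves.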

Paper Structure

This paper contains 26 sections, 1 theorem, 42 equations, 4 figures, 2 tables, 3 algorithms.

Key Result

Theorem 1

The time complexity of a reduction operation involving an associative function, such as $\textrm{logsumexp}_a$ or $\max_a$, over an array of size $N$ using $p$ parallel processors is $O\left(\frac{N}{p} + \log p\right)$. When the number of processors $p$ is equal to $N$, the time complexity is reduc
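The bound in Theorem 1 has two parts: each of the $p$ processors first reduces its own $N/p$ elements locally, and the $p$ partial results are then merged pairwise in about $\log_2 p$ rounds. Substituting $p = N$ into $\frac{N}{p} + \log p$ gives $1 + \log N$. The sketch below illustrates only the second part, counting the pairwise rounds for a logsumexp reduction. It is a sequential simulation with hypothetical helper names, and it assumes a non-empty input.

```python
import math

def lse_combine(a, b):
    # Associative pairwise logsumexp: log(exp(a) + exp(b)), computed stably.
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def tree_reduce(items, combine):
    # Merge neighbours pairwise until one value remains.
    # Every pass of the while loop models one parallel round.
    level = list(items)
    rounds = 0
    while len(level) > 1:
        merged = [combine(level[i], level[i + 1])
                  for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            merged.append(level[-1])  # odd element carries over unchanged
        level = merged
        rounds += 1
    return level[0], rounds

# Eight per-processor partial results (arbitrary illustrative values).
partials = [0.1 * i for i in range(8)]
value, rounds = tree_reduce(partials, lse_combine)

# Direct logsumexp over the same values, for comparison.
direct = math.log(sum(math.exp(x) for x in partials))
```

With 8 partial results the loop runs 3 rounds, whereas a sequential chain over the same values would need 7 combine steps.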

Figures (4)

  • Figure 1: Ring and Tree Attention Topologies. Because the logsumexp and max operations used in Tree Attention are associative, it is possible to structure the reduction across the sequence as a tree, requiring asymptotically fewer communication steps than Ring Attention as well as less memory and communication volume.
  • Figure 2: NCCL Send/Recv between two H100 GPUs intra-node and inter-node. GPU clusters offer a two-tier topology where intra-node bandwidth is significantly higher than inter-node. Algorithms such as Tree Attention exploit this topology by reducing inter-node communication requirements, enabling better overlap of communication with computation.
  • Figure 3: Execution time of 16-head Tree Attention vs Ring Attention for different sizes of GPU cluster (from 1 to 16 H100 DGX nodes). Relative execution times are indexed to the Ring Attention times at a sequence length of 80k tokens.
  • Figure 4: Peak memory usage of a single attention block with Tree Attention vs Ring Attention when sharded between two RTX 4090s. Results were taken using the JAX memory profiler on one GPU. The difference in peak memory scales with hidden size and sequence length.

Theorems & Definitions (2)

  • Theorem 1
  • proof