Tree Attention: Topology-aware Decoding for Long-Context Attention on GPU clusters
Vasudev Shyam, Jonathan Pilault, Emily Shepperd, Quentin Anthony, Beren Millidge
TL;DR
Tree Attention reframes self-attention as the gradient of an energy function and uses a tree reduction to parallelize the reduction over the sequence axis across multiple GPUs. This topology-aware approach computes exact attention in an asymptotically logarithmic number of communication steps, and it decodes up to 8x faster than Ring Attention while using less peak memory and less communication. Experiments show these gains for long-context decoding with Llama models and confirm that the method ports to DGX and AMD clusters. Tree Attention therefore offers a practical path to efficient long-context decoding on multi-node GPU clusters without sacrificing exactness.
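To make the energy-function view concrete, the following is a short worked equation for a single decoding query. The symbols (q, k_a, v_a, the auxiliary source vector ζ, and the sequence length N) are notation assumed here for illustration, and the usual 1/sqrt(d) scaling is omitted for brevity.

```latex
% Assumed notation: q is one query, (k_a, v_a) are the N cached key/value
% pairs, and \zeta is an auxiliary source vector with the same shape as v_a.
\[
F(\zeta) = \log \sum_{a=1}^{N} \exp\!\left( q \cdot k_a + \zeta \cdot v_a \right)
\]
\[
\left. \frac{\partial F}{\partial \zeta} \right|_{\zeta = 0}
  = \sum_{a=1}^{N} \frac{\exp(q \cdot k_a)}{\sum_{b=1}^{N} \exp(q \cdot k_b)} \, v_a
  = \sum_{a=1}^{N} \operatorname{softmax}_a(q \cdot k_a) \, v_a
\]
```

The right-hand side is ordinary softmax attention for that query. Because F is a log-sum-exp over the sequence axis, and both the sum and the max used to stabilize it are associative, the reduction can be split across devices and merged pairwise in any order.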
Abstract
Our formulation of self-attention as the gradient of an energy function reveals that the reduction across the sequence axis can be computed efficiently in parallel through a tree reduction. Our algorithm, called Tree Attention, parallelizes exact attention computation across multiple GPUs and enables cross-device decoding to be performed asymptotically faster (up to 8x faster in our experiments) than state-of-the-art approaches such as Ring Attention, while also requiring significantly less communication volume and incurring 2x lower peak memory. We demonstrate that Tree Attention speeds up decoding by up to 4x on Llama 3.1-8B and can be applied to a variety of hardware and networking setups, such as H100 DGX nodes, AMD MI300X nodes, and PCIe-connected NVIDIA RTX 4090s. Our code is publicly available at https://github.com/Zyphra/tree_attention
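The tree reduction over the sequence axis can be illustrated with a minimal single-process sketch. It is not the implementation in the repository above: the "devices" are simulated as Python lists, all function names are hypothetical, the 1/sqrt(d) scaling is omitted, and at least one non-empty shard is assumed. Each shard summarizes its keys and values as a (max score, sum of exponentials, weighted value sum) triple, and the triples are merged pairwise with an associative operator.

```python
import math

def partial_attention(q, keys, values):
    """Local pass over one shard: (max_score, sum_exp, weighted_value_sum)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # shifted by m for stability
    denom = sum(weights)
    dim = len(values[0])
    numer = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return m, denom, numer

def combine(a, b):
    """Associative merge of two shard summaries under a shared max."""
    m_a, d_a, n_a = a
    m_b, d_b, n_b = b
    m = max(m_a, m_b)
    s_a, s_b = math.exp(m_a - m), math.exp(m_b - m)
    denom = d_a * s_a + d_b * s_b
    numer = [x * s_a + y * s_b for x, y in zip(n_a, n_b)]
    return m, denom, numer

def tree_reduce(items, op):
    """Pairwise reduction; returns the result and the number of rounds used."""
    rounds = 0
    while len(items) > 1:
        merged = [op(items[i], items[i + 1]) for i in range(0, len(items) - 1, 2)]
        if len(items) % 2:  # an unpaired item is carried to the next round
            merged.append(items[-1])
        items = merged
        rounds += 1
    return items[0], rounds

def tree_attention_decode(q, key_shards, value_shards):
    """Exact attention output for one query, with keys/values sharded."""
    partials = [partial_attention(q, k, v) for k, v in zip(key_shards, value_shards)]
    (_, denom, numer), rounds = tree_reduce(partials, combine)
    return [x / denom for x in numer], rounds

def reference_attention(q, keys, values):
    """Unsharded softmax attention for comparison."""
    _, denom, numer = partial_attention(q, keys, values)
    return [x / denom for x in numer]
```

With p shards the merge finishes in ceil(log2(p)) rounds, which is the sense in which the number of communication steps is logarithmic, and the result matches unsharded attention up to floating-point rounding. On a real cluster the pairwise merges would be carried out by a collective operation rather than a Python loop.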
