
Longer Attention Span: Increasing Transformer Context Length with Sparse Graph Processing Techniques

Nathaniel Tomczak, Sanmukh Kuppannagari

TL;DR

This work tackles the quadratic memory and time growth of transformer attention by reframing attention as a graph computation, where tokens are nodes and the attention mask defines edges. It introduces six work-optimal graph-processing algorithms that realize true sparsity, performing only the necessary computations for arbitrary masks and using online softmax to avoid storing the full attention matrix. Empirically, the authors demonstrate substantial speedups over dense baselines (e.g., FlashAttention) for very long sequences and achieve context lengths up to $160\times 10^6$ on a single NVIDIA A100 80GB GPU, with a PyTorch back-end for seamless integration. The results indicate that combining graph-based sparsity with online softmax can unlock extremely long context lengths, enabling scalable training for domains requiring massive sequence lengths such as genomics and large-scale language modeling.
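
The per-query computation summarized above can be written out as follows. This assumes the standard scaled dot-product similarity. We write $N(i)$ for the set of keys that the mask connects to query $i$ and $j_1, \dots, j_T$ for those keys in visiting order; this notation is ours and may differ from the paper's. $w_{ij}$ is the query-key similarity, as in Figure 1.

$$
\begin{aligned}
w_{ij} &= \frac{q_i^\top k_j}{\sqrt{d_k}}, \qquad
o_i = \frac{\sum_{j \in N(i)} e^{w_{ij}}\, v_j}{\sum_{j \in N(i)} e^{w_{ij}}},\\
m^{(t)} &= \max\!\left(m^{(t-1)},\, w_{ij_t}\right), \qquad m^{(0)} = -\infty,\\
\ell^{(t)} &= \ell^{(t-1)}\, e^{m^{(t-1)} - m^{(t)}} + e^{w_{ij_t} - m^{(t)}}, \qquad \ell^{(0)} = 0,\\
a^{(t)} &= a^{(t-1)}\, e^{m^{(t-1)} - m^{(t)}} + e^{w_{ij_t} - m^{(t)}}\, v_{j_t}, \qquad a^{(0)} = 0,\\
o_i &= a^{(T)} / \ell^{(T)}.
\end{aligned}
$$

By induction, $\ell^{(t)} = \sum_{s \le t} e^{w_{ij_s} - m^{(t)}}$, and $a^{(t)}$ is the matching weighted sum of values, so the ratio equals the masked softmax output. Each query evaluates $|N(i)|$ similarities. The total work is therefore $\sum_i |N(i)| = |E|$, the number of edges (unmasked entries), rather than $L^2$. Only $m$, $\ell$, and $a$ are kept per query, so no $L \times L$ matrix is stored.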

Abstract

Transformers have demonstrated great success in numerous domains, including natural language processing and bioinformatics. This success stems from the attention mechanism, which these models use to represent and propagate pairwise interactions between individual tokens of sequential data. However, the primary limitation of this operation is its quadratic memory and time complexity with respect to the input's context length, i.e., the length of the sequence over which the interactions need to be captured. This significantly limits the length of sequences on which these models can perform inference. Extensive research has sought to reduce the number of pairwise interactions to sub-quadratic in the context length by introducing sparsity into the attention mechanism through sparse attention masks. However, efficient implementations that achieve "true sparsity" are lacking. In this work, we address this issue by proposing a graph computing view of attention, in which tokens are the nodes of a graph and the attention mask determines its edges. Using this view, we develop graph processing algorithms to implement the attention mechanism. We demonstrate both theoretically and empirically that our algorithms perform only the needed computations, i.e., they are work optimal. We also perform extensive experiments with popular attention masks to explore the impact of sparsity on execution time and achievable context length. Our experiments demonstrate significant speedups in execution time over state-of-the-art attention implementations such as FlashAttention for large sequence lengths. We also demonstrate that our algorithms achieve sequence lengths as high as 160 million on a single NVIDIA A100 GPU (SXM4 80GB).
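
To make the graph view concrete, here is a minimal pure-Python sketch. It is not the authors' code and shows only one simple row-wise traversal, whereas the paper describes six algorithms. It assumes the mask is stored as a CSR adjacency structure (`row_ptr`, `col_idx`), which is our choice. Each query node visits only its outgoing edges while maintaining an online softmax. The names `csr_from_mask`, `graph_attention`, and `dense_masked_attention` are hypothetical, and the dense version exists only as a correctness reference.

```python
import math
import random


def csr_from_mask(mask):
    """Convert a dense boolean mask (list of rows) into CSR arrays.

    Row i of the CSR structure lists the keys j that query i may attend to,
    i.e. the outgoing edges i -> j of token (node) i.
    """
    row_ptr, col_idx = [0], []
    for row in mask:
        col_idx.extend(j for j, keep in enumerate(row) if keep)
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx


def graph_attention(Q, K, V, row_ptr, col_idx):
    """Attention restricted to the edges of a CSR graph, with an online softmax.

    One dot product is computed per edge, and no L x L score matrix is stored:
    each query keeps only a running max, a running denominator and a running
    weighted sum of values.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i in range(len(Q)):
        m = -math.inf             # running maximum score for query i
        denom = 0.0               # running sum of exp(score - m)
        acc = [0.0] * len(V[0])   # running sum of exp(score - m) * v_j
        for e in range(row_ptr[i], row_ptr[i + 1]):
            j = col_idx[e]
            score = scale * sum(q * k for q, k in zip(Q[i], K[j]))
            m_new = max(m, score)
            rescale = math.exp(m - m_new)  # re-base old partial sums on the new max
            w = math.exp(score - m_new)
            denom = denom * rescale + w
            acc = [a * rescale + w * v for a, v in zip(acc, V[j])]
            m = m_new
        # Convention chosen here: a query with no edges gets a zero output.
        out.append([a / denom for a in acc] if denom > 0 else [0.0] * len(acc))
    return out


def dense_masked_attention(Q, K, V, mask):
    """Reference: masked softmax over the full L x L score matrix.

    Assumes every query has at least one unmasked key.
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for i, q in enumerate(Q):
        scores = [scale * sum(a * b for a, b in zip(q, k)) if mask[i][j] else -math.inf
                  for j, k in enumerate(K)]
        mx = max(scores)
        ws = [math.exp(s - mx) for s in scores]
        total = sum(ws)
        out.append([sum(w * v[c] for w, v in zip(ws, V)) / total for c in range(len(V[0]))])
    return out


# Usage: a causal sliding window of width 3 (query i sees keys i-2, i-1, i).
L, d = 8, 4
rng = random.Random(0)
Q = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(L)]
K = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(L)]
V = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(L)]
mask = [[0 <= i - j <= 2 for j in range(L)] for i in range(L)]

row_ptr, col_idx = csr_from_mask(mask)
sparse = graph_attention(Q, K, V, row_ptr, col_idx)
dense = dense_masked_attention(Q, K, V, mask)
print(len(col_idx), "edges instead of", L * L)  # 21 edges instead of 64
```

The inner loop of `graph_attention` runs once per edge, so its work tracks the number of unmasked entries (21 here) rather than $L^2$ (64). Both functions produce the same outputs up to floating-point rounding.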

Paper Structure

This paper contains 33 sections, 2 equations, 6 figures, 3 tables, and 1 algorithm.

Figures (6)

  • Figure 1: Visualization of the attention mechanism. $w_{ij}$ is the similarity of Query $i$ and Key $j$. The similarities are used to weight the Values.
  • Figure 2: Visualizing different masks from the Longformer (left and center) and BigBird (right) transformers. The white cells indicate masked terms that are not considered. The black cells are local attention, which can be dilated. The blue cells correspond to global attention. The orange cells are uniform random attention.
  • Figure 3: Plotting the average log runtime performance from 15 benchmark runs of our algorithms and PyTorch's C++ back-end SDP attention implementation across three NVIDIA GPUs as the context length ($L$), embedded dimension ($d_k$), and sparsity factor ($S_f$) vary. $L$ varies between each plot, increasing from left-to-right in a row (corresponding to a decrease in sparsity). $d_k$ varies by color within a plot. Log $S_f$ increases within a plot along the x-axis. (a) shows the runtimes for the NVIDIA V100 system, (b) has the runtimes for the NVIDIA L40 system, and (c) presents the runtimes for the NVIDIA A100 system.
  • Figure 4: Plots showing the trend in maximum log context length ($L$) achievable by different algorithms as the log sparsity factor ($S_f$) increases along the x-axis. This is for single-headed attention on one NVIDIA A100 GPU. The left plots showcase performance for 32-bit floating-point values and the right for 16-bit floating-point values. The embedded dimension changes between (a) and (b). (a) shows $L$ for $d_k = 64$ and (b) shows $L$ for $d_k = 128$.
  • Figure 5: Plotting the average log runtime performance from 15 benchmark runs of PyTorch's FlashAttention implementation against our local attention with either constant window size (left plot) or constant sparsity (right plot). The x-axis shows increasing context length from left-to-right.
  • ...and 1 more figure
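
The mask patterns in Figure 2 (local attention, optionally dilated; global attention; uniform random attention) can be produced directly in the graph form described in the abstract, without materializing a dense mask. The sketch below is our own illustration, not the paper's code. `sparse_mask_csr` and `density` are hypothetical names, and the CSR layout is an assumption. The random component samples individual keys per query, which simplifies BigBird's own implementation, where random attention operates on blocks of tokens.

```python
import random


def sparse_mask_csr(L, window, dilation=1, global_tokens=(), n_random=0, seed=0):
    """Build a Longformer/BigBird-style attention mask directly as CSR arrays.

    local : query i attends to keys i + dilation * t for -window <= t <= window
    global: listed tokens attend to every key, and every query attends to them
    random: each other query also attends to n_random uniformly sampled keys
    The dense L x L mask is never built, so memory grows with the edge count.
    """
    rng = random.Random(seed)
    glob = set(global_tokens)
    row_ptr, col_idx = [0], []
    for i in range(L):
        if i in glob:
            cols = set(range(L))  # global row: a full row of edges
        else:
            cols = {i + dilation * t for t in range(-window, window + 1)
                    if 0 <= i + dilation * t < L}                # (dilated) local band
            cols |= glob                                         # global columns
            cols |= set(rng.sample(range(L), min(n_random, L)))  # random keys
        col_idx.extend(sorted(cols))
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx


def density(row_ptr, L):
    """Fraction of the L x L score matrix that actually has to be computed."""
    return row_ptr[-1] / (L * L)


L = 1000
local_ptr, _ = sparse_mask_csr(L, window=2)
print(local_ptr[-1])  # 4994 edges out of 1,000,000 possible pairs

bigbird_ptr, _ = sparse_mask_csr(L, window=3, global_tokens=(0, 1), n_random=3)
print(density(bigbird_ptr, L))
```

With a fixed window, the edge count grows linearly in $L$ while the number of possible pairs grows as $L^2$. This gap is why true sparsity pays off at long context lengths.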