
On Fine-Grained I/O Complexity of Attention Backward Passes

Xiaoyu Li, Yingyu Liang, Zhenmei Shi, Zhao Song, Song Yue, Jiahao Zhang

TL;DR

This work addresses the I/O bottlenecks of attention in long-context transformers by analyzing data movement between a fast cache and large memory using the red-blue pebble game. It derives a tight backward-pass I/O bound that scales as $\Theta\left(\min\left\{\frac{n^2 d^2 + n d^3}{M}, \frac{n^2 d + n d^2}{\sqrt{M}}\right\}\right)$ with a crossover at $M=\Theta(d^2)$, and shows FlashAttention is optimal in the large-cache regime while introducing a new efficient small-cache algorithm. The paper also provides fine-grained lower bounds for sparse attention and integrates these results with existing forward-pass analyses to yield a complete I/O complexity picture for attention. These findings offer practical guidance for designing hardware-aware, memory-efficient training and inference pipelines for large language models.
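
The following Python sketch illustrates how the two regimes trade off. It evaluates the expression inside the Θ with every hidden constant set to 1. That is an assumption, because the true constants are not stated here. The sketch therefore only indicates which term dominates and does not predict actual transfer counts. The function name `backward_io_bound` is hypothetical.

```python
import math


def backward_io_bound(n: int, d: int, M: int) -> tuple[float, str]:
    """Evaluate the min{...} expression of the backward-pass bound with constants set to 1.

    Hypothetical helper: returns the smaller of the two terms and the regime that supplied it.
    """
    large_cache = (n**2 * d**2 + n * d**3) / M          # (n^2 d^2 + n d^3) / M
    small_cache = (n**2 * d + n * d**2) / math.sqrt(M)  # (n^2 d + n d^2) / sqrt(M)
    if large_cache <= small_cache:
        return large_cache, "large-cache"
    return small_cache, "small-cache"
```

For example, take n = 4096 and d = 128, so that d^2 = 16384. A cache of M = 1024 falls in the small-cache regime, and M = 65536 falls in the large-cache regime. At M = d^2 the two terms coincide.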

Abstract

Large Language Models (LLMs) exhibit exceptional proficiency in handling extensive context windows in natural language. Nevertheless, the quadratic scaling of attention computation with sequence length creates substantial efficiency bottlenecks, necessitating I/O-optimized algorithms. In this work, we systematically examine the I/O complexity of attention mechanisms, with a specific emphasis on the backward pass under both small and large cache settings. Using the red-blue pebble game framework, we derive tight bounds on I/O complexity across the full spectrum of cache sizes. We prove that FlashAttention, one of the current industry standards, is optimal in the large-cache scenario for both forward and backward passes. For small-cache environments, we introduce a novel algorithm that outperforms existing methods and attains the tight bound. Furthermore, we extend our investigation to sparse attention by establishing fine-grained lower bounds for both forward and backward passes across all cache configurations. Together, our results complete the theoretical picture of I/O complexity in attention mechanisms and provide critical guidance for the development of efficient LLM training and inference systems.

Paper Structure

This paper contains 38 sections, 33 theorems, 25 equations, 3 figures, 1 table, 9 algorithms.

Key Result

Theorem 1.1

Let $n$ be the sequence length, $d$ the head dimension, and $M$ the cache size. The I/O complexity of attention backward computation under standard matrix multiplication is $\Theta \left(\min \left\{\frac{n^2d^2 + nd^3}{M}, \frac{n^2d + nd^2}{\sqrt{M}} \right\}\right).$
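
A short algebraic check shows where the crossover at $M=\Theta(d^2)$ comes from. This derivation is not quoted from the paper; it only rearranges the two terms of the bound above. The large-cache numerator is exactly $d$ times the small-cache numerator, so the ratio of the two terms is $d/\sqrt{M}$:

```latex
\frac{n^2d^2 + nd^3}{M} = \frac{d\,(n^2d + nd^2)}{M}
\quad\Longrightarrow\quad
\frac{(n^2d^2 + nd^3)/M}{(n^2d + nd^2)/\sqrt{M}} = \frac{d}{\sqrt{M}},

\Theta\!\left(\min\left\{\frac{n^2d^2 + nd^3}{M},\ \frac{n^2d + nd^2}{\sqrt{M}}\right\}\right)
=
\begin{cases}
\Theta\!\left(\dfrac{n^2d^2 + nd^3}{M}\right), & M \ge d^2 \quad\text{(large cache)},\\[2ex]
\Theta\!\left(\dfrac{n^2d + nd^2}{\sqrt{M}}\right), & M < d^2 \quad\text{(small cache)}.
\end{cases}
```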

Figures (3)

  • Figure 1: Attention backward I/O complexity comparison. The $x$-axis is the cache size, and the $y$-axis is the I/O complexity. The red line represents our tight upper/lower bound (Theorem 1.1), and the blue dashed line denotes the upper bound for FlashAttention [dfe+22]. The crossover point is $M=\Theta(d^2)$, which separates the large-cache and small-cache settings. The results show that FlashAttention is optimal when $M = \Omega(d^2)$.
  • Figure 2: The computational graph for the attention forward and backward passes. The blue boxes are input matrices, the gray boxes are intermediate matrices, the green box is the forward output, and the orange box is the final gradient matrix. Here, $A_1,A_2,A_3$ denote the previous inputs, $\mathrm{d} O$ denotes the upstream gradient, and $X,Y$ denote the attention weights. Detailed definitions of each variable are given in the paper's preliminaries sections.
  • Figure 3: This diagram shows a summation tree with $d=2$ in the computational graph for the backward pass of attention using standard matrix multiplication. The orange and green nodes are the inputs of the level-$1$ summation tree. The brown nodes, together with the blue nodes (outputs of the level-$1$ summation tree), serve as inputs to the level-$2$ summation tree. The purple nodes represent the target output. As $d$ grows, the summation tree expands with additional layers, each introducing intermediate nodes that hold the sums of pairs of nodes from the previous layer, giving $1 + \log_2 d$ layers in total.
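
The $1 + \log_2 d$ layer count from the Figure 3 caption can be reproduced with a small sketch. The hypothetical helper `summation_tree_layers` performs a single pairwise reduction of $d$ terms, such as the $d$ products that make up one entry of a matrix product, and assumes $d$ is a power of two. It does not model the two-level composition or the node colors shown in the figure.

```python
def summation_tree_layers(values: list[float]) -> list[list[float]]:
    """Build a binary summation tree by repeatedly adding adjacent pairs.

    Hypothetical helper. Layer 0 holds the d leaf terms and the last layer holds
    the single total, so a power-of-two d yields 1 + log2(d) layers.
    """
    d = len(values)
    if d == 0 or d & (d - 1):
        raise ValueError("expected a power-of-two number of terms")
    layers = [list(values)]
    while len(layers[-1]) > 1:
        prev = layers[-1]
        # Each new node is the sum of one adjacent pair from the previous layer.
        layers.append([prev[i] + prev[i + 1] for i in range(0, len(prev), 2)])
    return layers
```

For $d = 4$, the elementwise products of (1, 2, 3, 4) and (5, 6, 7, 8) give the layers [5, 12, 21, 32], [17, 53], [70]. That is 1 + log2(4) = 3 layers.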

Theorems & Definitions (75)

  • Theorem 1.1: Main result
  • Theorem 1.2: Lower bound for sparse attention forward and backward, informal version of the corresponding formal theorem
  • Definition 3.1: Attention forward computation
  • Definition 3.2: Attention backward gradient
  • Remark 3.3
  • Definition 3.4: Red-blue pebble game [hk81]
  • Definition 3.5: I/O complexity [hk81]
  • Theorem 4.1: Large cache upper bound, informal version of the corresponding formal theorem
  • Theorem 4.2: Large cache lower bound, informal version of the corresponding formal theorem
  • Theorem 4.3: Small cache upper bound, informal version of the corresponding formal theorem
  • ...and 65 more
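
Definition 3.4 refers to the red-blue pebble game of [hk81]. The sketch below is a hypothetical replay checker, `count_io`, built on the standard Hong-Kung rules as they are commonly stated. The paper's exact formulation may differ. Red pebbles model values held in the cache of size $M$, and blue pebbles model values in slow memory. Loads and stores each cost one I/O, while computing and evicting are free. The checker only scores a given schedule; the I/O complexity in Definition 3.5 is the minimum of this count over all valid schedules.

```python
def count_io(
    preds: dict[str, list[str]],
    inputs: set[str],
    outputs: set[str],
    M: int,
    schedule: list[tuple[str, str]],
) -> int:
    """Replay a red-blue pebbling schedule and return its number of I/O moves.

    Hypothetical checker: 'load' (blue -> red) and 'store' (red -> blue) each cost
    one I/O; 'compute' and 'evict' are free. Raises ValueError on an illegal move.
    """
    red: set[str] = set()         # vertices currently held in fast memory (cache)
    blue: set[str] = set(inputs)  # vertices currently in slow memory; inputs start here
    io = 0
    for op, v in schedule:
        if op == "load":
            if v not in blue:
                raise ValueError(f"cannot load {v}: it has no blue pebble")
            red.add(v)
            io += 1
        elif op == "store":
            if v not in red:
                raise ValueError(f"cannot store {v}: it has no red pebble")
            blue.add(v)
            io += 1
        elif op == "compute":
            if v in inputs:
                raise ValueError(f"cannot compute {v}: inputs must be loaded")
            # A vertex may be computed only when all its predecessors are in cache.
            if any(p not in red for p in preds.get(v, [])):
                raise ValueError(f"cannot compute {v}: a predecessor is not in cache")
            red.add(v)
        elif op == "evict":
            red.discard(v)  # frees a cache slot at no I/O cost
        else:
            raise ValueError(f"unknown move {op!r}")
        if len(red) > M:
            raise ValueError(f"cache overflow: {len(red)} red pebbles exceed M={M}")
    if not outputs <= blue:
        raise ValueError("some outputs never received a blue pebble")
    return io
```

Consider the two-input graph s = a + b. The schedule load a, load b, compute s, store s costs 3 I/Os when M = 3. The same schedule is rejected when M = 2, because computing s requires three red pebbles at once.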