The I/O Complexity of Attention, or How Optimal is Flash Attention?

Barna Saha, Christopher Ye

TL;DR

The paper addresses the I/O bottlenecks in self-attention under a two-level memory model and proves that FlashAttention is I/O-optimal for all $M \ge d^2$, while providing improved I/O strategies for $M < d^2$. It develops a compression-based lower bound framework and establishes a novel link between communication complexity and I/O complexity, showing this connection holds even in binary settings via BCH codes. In the large-cache regime, it shows no improvement over FlashAttention is possible even with fast matrix multiplication, and in the small-cache regime it characterizes the problem as equivalent to rectangular matrix multiplication, yielding a $\Theta\left( \frac{N^2 d}{\sqrt{M}} \right)$ I/O bound. The findings have practical implications for scalable transformer implementations and provide a theoretical foundation for understanding I/O limitations beyond standard matrix multiplication.
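As a quick check on why $d^2$ is the threshold between the two regimes, compare FlashAttention's $\frac{N^2 d^2}{M}$ I/O cost with the small-cache bound $\frac{N^2 d}{\sqrt{M}}$. Constants are suppressed, and this is only a sanity check on the stated bounds, not part of the paper's proof:

$$
\frac{N^2 d^2}{M} \le \frac{N^2 d}{\sqrt{M}}
\iff \frac{d}{\sqrt{M}} \le 1
\iff M \ge d^2 .
$$

At $M = d^2$ both expressions equal $N^2$, so the two regimes meet there. For $M < d^2$ the second expression is strictly smaller, which is consistent with an improved strategy existing in the small-cache regime.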

Abstract

Self-attention is at the heart of the popular Transformer architecture, yet suffers from quadratic time and memory complexity. The breakthrough FlashAttention algorithm revealed I/O complexity as the true bottleneck in scaling Transformers. Given two levels of memory hierarchy, a fast cache (e.g. GPU on-chip SRAM) and a slow memory (e.g. GPU high-bandwidth memory), the I/O complexity measures the number of accesses to memory. FlashAttention computes attention using $\frac{N^2d^2}{M}$ I/O operations where $N$ is the dimension of the attention matrix, $d$ the head-dimension and $M$ the cache size. However, is this I/O complexity optimal? The known lower bound only rules out an I/O complexity of $o(Nd)$ when $M=\Theta(Nd)$, since the output that needs to be written to slow memory is $\Omega(Nd)$. This leads to the main question of our work: Is FlashAttention I/O optimal for all values of $M$? We resolve the above question in its full generality by showing an I/O complexity lower bound that matches the upper bound provided by FlashAttention for any values of $M \geq d^2$ within any constant factors. Further, we give a better algorithm with lower I/O complexity for $M < d^2$, and show that it is optimal as well. Moreover, our lower bounds do not rely on using combinatorial matrix multiplication for computing the attention matrix. We show even if one uses fast matrix multiplication, the above I/O complexity bounds cannot be improved. We do so by introducing a new communication complexity protocol for matrix compression, and connecting communication complexity to I/O complexity. To the best of our knowledge, this is the first work to establish a connection between communication complexity and I/O complexity, and we believe this connection could be of independent interest and will find many more applications in proving I/O complexity lower bounds in the future.
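The $\frac{N^2d^2}{M}$ figure quoted above can be illustrated by counting slow-memory transfers in a tiled schedule. The sketch below is a minimal illustration, not the paper's or FlashAttention's actual implementation. The function names `flash_attention_io` and `reference_bound` are hypothetical. The block size $\lfloor M/(4d) \rfloor$ is an assumption modeled on FlashAttention-style tiling. Softmax running statistics are ignored because they add only lower-order traffic.

```python
import math


def flash_attention_io(N: int, d: int, M: int) -> int:
    """Count element transfers between slow memory and cache.

    Assumed schedule (illustrative): the outer loop streams K and V in
    blocks of B_c rows; for each such block, every row of Q is read and
    every row of the partial output O is read and written back.
    """
    B_c = max(1, M // (4 * d))   # K/V rows per block (assumed block size)
    T_c = math.ceil(N / B_c)     # number of K/V blocks

    kv_reads = 2 * N * d         # each entry of K and V is read once
    # Per K/V block: read Q (N*d), read O (N*d), write O (N*d).
    # The Q-block size does not change this total, so it is omitted.
    qo_traffic = T_c * 3 * N * d
    return kv_reads + qo_traffic


def reference_bound(N: int, d: int, M: int) -> float:
    """The N^2 d^2 / M expression, with constants dropped."""
    return N * N * d * d / M


if __name__ == "__main__":
    N, d = 4096, 64
    for M in (d * d, 4 * d * d, 16 * d * d):
        io = flash_attention_io(N, d, M)
        print(f"M={M:>6}  transfers={io:>12}  ratio={io / reference_bound(N, d, M):.3f}")
```

Under these assumptions the ratio of counted transfers to $\frac{N^2d^2}{M}$ stays within a small constant range as $M$ grows, which is what a $\Theta(\cdot)$ statement predicts. The constant itself depends on the assumed block size and carries no significance.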

Paper Structure

This paper contains 19 sections, 26 theorems, 34 equations, 3 figures, and 1 algorithm.

Key Result

Theorem 1.1

The I/O complexity of attention with standard matrix multiplication is $\Theta\left( \min\left( \frac{N^2 d^2}{M}, \frac{N^2 d}{\sqrt{M}} \right) \right)$.

Figures (3)

  • Figure 1: Computational Graph for the Attention Mechanism.
  • Figure 2: A single summation tree with $d = 4$. Orange vertices denote inputs from $Q$. Yellow vertices denote inputs from $K$. Grey and green vertices denote level-1 vertices. Observe that these are disjoint for each summation tree. The green vertex specifically denotes an entry in $QK^T$. Solid edges denote multiplications and dotted edges denote additions. Vertices in the blue box denote vertices in $L_1$.
  • Figure 3: Two examples of summation trees containing dominator and minimum vertices. Blue vertices denote elements of the partition $V_i$.

Theorems & Definitions (54)

  • Theorem 1.1
  • Theorem 1.1
  • Theorem 1.1
  • Definition 2.1: Red-Blue Pebble Game [redblue1981]
  • Definition 2.2: I/O Complexity [redblue1981]
  • Definition 2.3: Dominator Set
  • Definition 2.4: Minimum Set
  • Definition 2.5: Vertex Subset Dependence
  • Definition 2.6: $M$-partition [redblue1981]
  • Lemma 2.7: Theorem 3.1 of [redblue1981]
  • ...and 44 more