
Streaming Attention Approximation via Discrepancy Theory

Insu Han, Michael Kapralov, Ekaterina Kochetkova, Kshiteej Sheth, Amir Zandieh

TL;DR

This work tackles the memory bottleneck of long-context generation in Transformer-based models by introducing BalanceKV, a streaming attention approximation method grounded in discrepancy theory. It constructs SoftmaxBalance, which balances exponentiated key-query interactions. BalanceKV implements it in the streaming setting using a MergeAndReduce framework, which yields sublinear memory usage with provable accuracy guarantees for Attn$(q_j,K_j,V_j)$. A lower bound based on the INDEX problem shows that a trade-off between space and accuracy is inherent. Empirically, BalanceKV improves end-to-end performance on long-context benchmarks such as LongBench and Needle-In-A-Haystack, and its efficiency is competitive with state-of-the-art cache-compression methods.
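For reference, the quantity that every method here approximates can be written down directly. The NumPy sketch below computes exact softmax attention $\text{Attn}(q_j, K_j, V_j)$ for one query. It assumes that $K_j$ and $V_j$ stack the keys and values $k_1, \ldots, k_j$ and $v_1, \ldots, v_j$, and that any $1/\sqrt{d}$ scaling is already folded into $q_j$. It also includes a naive streaming loop that caches every token, so memory grows linearly in $j$. That linear growth is the bottleneck BalanceKV targets. The function names are illustrative and are not taken from the paper.

```python
import numpy as np

def exact_attention(q, K, V):
    """Attn(q, K, V) = softmax(K q)^T V for one query q of shape (d,),
    keys K of shape (j, d) and values V of shape (j, d)."""
    scores = K @ q                    # (j,) key-query inner products
    scores = scores - scores.max()    # softmax is shift-invariant; shift for stability
    weights = np.exp(scores)
    weights /= weights.sum()
    return weights @ V                # (d,) convex combination of value rows

def streaming_exact(tokens):
    """Naive baseline: cache every (k, v) pair, so memory grows linearly in j."""
    keys, values, outputs = [], [], []
    for q, k, v in tokens:
        keys.append(k)
        values.append(v)
        outputs.append(exact_attention(q, np.array(keys), np.array(values)))
    return outputs
```

Any compressed cache, BalanceKV included, tries to return a vector close to `exact_attention` without storing all $j$ rows of `K` and `V`.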

Abstract

Large language models (LLMs) have achieved impressive success, but their high memory requirements present challenges for long-context token generation. In this paper we study the streaming complexity of attention approximation, a key computational primitive underlying token generation. Our main contribution is BalanceKV, a streaming algorithm for $ε$-approximating attention computations based on a geometric process for selecting a balanced collection of Key and Value tokens, as per Banaszczyk's vector balancing theory. We complement our algorithm with space lower bounds for streaming attention computation. Besides strong theoretical guarantees, BalanceKV exhibits empirically validated performance improvements over existing methods, both for attention approximation and for end-to-end performance on various long-context benchmarks.
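Vector balancing is what lets a kept half of the tokens stand in for the whole set. As an illustration only, the sketch below signs generic unit-norm vectors using the self-balancing walk of Alweiss, Liu and Sawhney (2021), which we take to be the ALS21 result cited as Theorem A.1. It then keeps the $+1$ half. The identity $\sum_i x_i = 2\sum_{\epsilon_i=+1} x_i - \sum_i \epsilon_i x_i$ shows that doubling the kept half recovers the full sum up to the signed-sum discrepancy. The constant $c$ and the clipping step (ALS21 instead declares failure) are simplifications we chose. SoftmaxBalance applies balancing to exponentiated key-query interactions, which this sketch does not model.

```python
import numpy as np

def self_balancing_signs(X, c, rng):
    """Sign each row of X (rows assumed to have norm <= 1) so that the running
    signed sum stays small, following an ALS21-style self-balancing walk."""
    w = np.zeros(X.shape[1])             # running signed sum
    signs = np.empty(len(X), dtype=int)
    for i, x in enumerate(X):
        corr = float(w @ x)
        # Choose +1 less often when x is already aligned with the signed sum.
        # ALS21 fails if |<w, x>| > c; clipping here is a simplification.
        p_plus = 0.5 - np.clip(corr, -c, c) / (2 * c)
        signs[i] = 1 if rng.random() < p_plus else -1
        w += signs[i] * x
    return signs

def halve(X, c=None, rng=None):
    """Keep the +1 half. Since sum(X) = 2 * sum(X[+]) - sum(signs * X),
    twice the kept sum approximates the full sum when the signed sum is small."""
    rng = np.random.default_rng() if rng is None else rng
    if c is None:
        c = 30 * np.log(max(len(X), 2))  # hypothetical choice, c = O(log n) as in ALS21
    signs = self_balancing_signs(X, c, rng)
    return X[signs == 1], signs
```

The random sign choice is biased against the current signed sum, which keeps that sum from drifting. A uniformly random split has no such bias.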

Paper Structure

This paper contains 34 sections, 11 theorems, 50 equations, 4 figures, 2 tables, 3 algorithms.

Key Result

Theorem 3.1

For any $r, \varepsilon > 0$, any positive integers $n, d$, and any sequence of tokens $(q_1, k_1, v_1), (q_2, k_2, v_2), \ldots, (q_n, k_n, v_n)$ with $q_j, k_j, v_j \in \mathbb{R}^d$ and $\|q_j\|_2, \|k_j\|_2 \leq r$ for all $j$, consider an invocation of BalanceKV (alg:main) with … . Then BalanceKV outputs a vector $z_j$ satisfying the approximation objective (eq:objective) with probability at least $1 - 1/\text{poly}(n)$ at every step $j$ of …
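Theorem 3.1 bounds how far the output $z_j$ may be from exact attention. One natural way to form $z_j$ from a retained, reweighted subset of tokens is to estimate the softmax numerator and denominator separately. The sketch below does this. It is an assumption for illustration: it does not reproduce eq:objective, and we have not confirmed that BalanceKV forms $z_j$ exactly this way. The helper names are ours. The relative error $\|z - \text{Attn}\|_2 / \|\text{Attn}\|_2$ is an assumed definition, and the metric plotted in the figures may differ.

```python
import numpy as np

def weighted_attention(q, K_s, V_s, w):
    """Estimate softmax attention from a retained subset (K_s, V_s) with weights w,
    e.g. w = 2**level for tokens that survived `level` halvings."""
    s = K_s @ q
    e = w * np.exp(s - s.max())   # the common shift cancels in the ratio below
    numerator = e @ V_s           # approximates sum_i exp(<q, k_i>) v_i (up to the shift)
    denominator = e.sum()         # approximates sum_i exp(<q, k_i>) (same shift)
    return numerator / denominator

def relative_error(z, target):
    """||z - target||_2 / ||target||_2 (assumed error measure)."""
    return float(np.linalg.norm(z - target) / np.linalg.norm(target))
```

If every token is retained with weight 1, the estimator equals exact attention. Scaling all weights by the same factor leaves it unchanged.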

Figures (4)

  • Figure 1: Comparison of relative errors across different layers of Llama-3.1-8B-Instruct (left) and Ministral-8B-Instruct-2410 (right) on the TriviaQA dataset.
  • Figure 2: Illustration of the tree structure of MergeAndReduce
  • Figure 3: Runtime and relative error for … across different layers and block sizes. In each figure, rows correspond to different batch sizes and columns to different compression rates.
  • Figure 4: Comparison of performance on the Needle-In-A-Haystack task using Llama-3.1-8B-Instruct. From top to bottom, the panels show StreamingLLM, SnapKV, PyramidKV, uniform sampling, and BalanceKV (alg:main).
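Figure 2 depicts MergeAndReduce as a tree. The sketch below is a hypothetical rendering of that idea. Tokens are grouped into blocks of size $B$. Whenever two summaries sit at the same level, they are concatenated and halved, and the survivors' weights are doubled, much like carries in a binary counter. At most one summary is kept per level, so storage is roughly $B(\log_2(n/B) + 1)$ rows plus the current buffer. The class name, the block size, and especially the placeholder random halving are assumptions; BalanceKV would use a balancing rule to choose the surviving half.

```python
import numpy as np

class MergeAndReduceSketch:
    """Hypothetical MergeAndReduce-style summary: at most one (K, V, w) per level."""

    def __init__(self, block_size, halve=None, seed=0):
        self.block_size = block_size
        self.rng = np.random.default_rng(seed)
        # Placeholder reduction: keep a uniformly random half. BalanceKV would
        # instead choose the half with a balancing (discrepancy-based) rule.
        self.halve = halve if halve is not None else self._random_half
        self.buffer = []   # raw (k, v) pairs not yet forming a full block
        self.levels = {}   # level -> (K, V, w)

    def _random_half(self, K, V, w):
        idx = self.rng.permutation(len(K))[: len(K) // 2]
        return K[idx], V[idx], 2 * w[idx]  # each survivor stands for twice as many tokens

    def insert(self, k, v):
        self.buffer.append((np.asarray(k), np.asarray(v)))
        if len(self.buffer) < self.block_size:
            return
        K = np.array([pair[0] for pair in self.buffer])
        V = np.array([pair[1] for pair in self.buffer])
        self.buffer = []
        self._push(0, (K, V, np.ones(len(K))))

    def _push(self, level, summary):
        # Binary-counter carry: merge equal-level summaries, halve, move up a level.
        while level in self.levels:
            K0, V0, w0 = self.levels.pop(level)
            K1, V1, w1 = summary
            summary = self.halve(np.vstack([K0, K1]), np.vstack([V0, V1]),
                                 np.concatenate([w0, w1]))
            level += 1
        self.levels[level] = summary

    def stored(self):
        """Return all retained keys, values and weights (levels plus buffer)."""
        parts = list(self.levels.values())
        if self.buffer:
            K = np.array([pair[0] for pair in self.buffer])
            V = np.array([pair[1] for pair in self.buffer])
            parts.append((K, V, np.ones(len(K))))
        if not parts:
            return None
        return (np.vstack([p[0] for p in parts]),
                np.vstack([p[1] for p in parts]),
                np.concatenate([p[2] for p in parts]))
```

A custom reduction with signature `(K, V, w) -> (K, V, w)` can be passed as `halve`. With the default random halving, the weights always sum to the number of tokens seen. This is what a weighted estimator needs in order to treat a level-$\ell$ survivor as standing for $2^\ell$ tokens.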

Theorems & Definitions (19)

  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Theorem 3.4
  • proof: Proof of thm:main-theorem
  • Theorem A.1: Theorem 1.1 in ALS21
  • proof: Proof of thm:BALANCE-vectors
  • proof
  • Theorem C.1
  • Definition C.2: The INDEX problem
  • ...and 9 more