SubGen: Token Generation in Sublinear Time and Memory
Amir Zandieh, Insu Han, Vahab Mirrokni, Amin Karbasi
TL;DR
SubGen addresses the linear growth of memory and compute in autoregressive decoding caused by KV caching in transformers. It presents a streaming attention framework that achieves sublinear memory and time in the context length $n$ by (i) maintaining a sublinear set of key-value samples via reservoir sampling, and (ii) approximating the softmax denominator using a clustering-based data structure, under the assumption that the keys are $(m,\delta)$-clusterable with $m = o(n)$. The authors prove that, for an error parameter $\nu > 0$ and queries with $\Vert \boldsymbol q_n \Vert_2 \le r$, choosing $t = O\big(\frac{1}{\nu^2} e^{2\delta r} \log n\big)$ keys sampled per cluster and $s = O\big(\frac{1}{\nu^2} d\big)$ value samples makes the estimator $\boldsymbol z_n$ satisfy $\big\Vert \boldsymbol z_n - \text{Attn}(\boldsymbol q_n,\boldsymbol K_n,\boldsymbol V_n) \big\Vert_2 \le \nu \big\Vert \text{softmax}(\boldsymbol K_n \boldsymbol q_n) \big\Vert_2 \big\Vert \boldsymbol V_n \big\Vert_{op}$ with high probability. Memory and runtime are $O(d(mt + s))$. Treating $\nu$ and $\delta r$ as constants gives $t = O(\log n)$, so this cost is sublinear in the context length whenever $m \log n = o(n)$. Empirically, SubGen outperforms prior KV-cache compression methods on long-context QA and line-retrieval tasks, illustrating its practical value for efficient long-context generation.
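To make the clustering idea concrete, here is a minimal Python sketch of a streaming estimate of the softmax denominator $\sum_i \exp(\langle \boldsymbol k_i, \boldsymbol q \rangle)$. It assumes a simple greedy rule: a key joins the first cluster whose center lies within $\delta/2$, so every cluster has diameter at most $\delta$. Each cluster keeps its size and a uniform reservoir of at most $t$ keys. The class `ClusteredDenominator`, its methods, and the toy data are hypothetical, and this is not claimed to be the authors' exact procedure. Because keys in one cluster are within $\delta$ of each other and $\Vert \boldsymbol q \Vert_2 \le r$, their exponentials differ by at most a factor $e^{\delta r}$, so each scaled cluster sample mean stays within that factor of the true cluster sum.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


class ClusteredDenominator:
    """Hypothetical streaming estimator of sum_i exp(<k_i, q>).

    Assumed greedy online clustering: a key joins the first cluster whose
    center is within delta / 2, so each cluster has diameter at most delta.
    Each cluster stores its size and a uniform reservoir of at most t keys.
    """

    def __init__(self, delta, t, seed=0):
        self.delta = delta
        self.t = t
        self.rng = random.Random(seed)
        self.centers = []  # first key of each cluster
        self.counts = []   # number of keys assigned to each cluster
        self.samples = []  # per-cluster uniform reservoir of keys

    def insert(self, key):
        for j, center in enumerate(self.centers):
            if math.dist(key, center) <= self.delta / 2:
                self.counts[j] += 1
                if len(self.samples[j]) < self.t:
                    self.samples[j].append(key)
                else:
                    # Algorithm R: the new key replaces a stored one w.p. t / count.
                    slot = self.rng.randrange(self.counts[j])
                    if slot < self.t:
                        self.samples[j][slot] = key
                return
        # No cluster is close enough: open a new one centered at this key.
        self.centers.append(key)
        self.counts.append(1)
        self.samples.append([key])

    def estimate(self, query):
        """Sum over clusters of (cluster size) * (mean exp(<k, q>) over its sample)."""
        total = 0.0
        for count, sample in zip(self.counts, self.samples):
            mean = sum(math.exp(dot(k, query)) for k in sample) / len(sample)
            total += count * mean
        return total


# Usage on hypothetical data: three tight groups of 2-D keys, interleaved in time.
rng = random.Random(1)
group_centers = [(0.0, 0.0), (3.0, 0.0), (0.0, 3.0)]
keys = [(cx + rng.uniform(-0.1, 0.1), cy + rng.uniform(-0.1, 0.1))
        for _ in range(200) for cx, cy in group_centers]
q = (0.5, -0.3)
est = ClusteredDenominator(delta=0.6, t=8)
for k in keys:
    est.insert(k)
exact = sum(math.exp(dot(k, q)) for k in keys)
print(len(est.centers), est.estimate(q) / exact)
```

Here only $mt = 24$ keys are stored instead of 600. If $t$ is at least the size of every cluster, the reservoirs hold all keys and the estimate is exact; smaller $t$ trades accuracy for memory.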
Abstract
Despite the significant success of large language models (LLMs), their extensive memory requirements pose challenges for deploying them in long-context token generation. The substantial memory footprint of LLM decoders arises from the necessity to store all previous tokens in the attention module, a requirement imposed by key-value (KV) caching. In this work, our focus is on developing an efficient compression technique for the KV cache. Empirical evidence indicates a significant clustering tendency within key embeddings in the attention module. Building on this key insight, we have devised a novel caching method with sublinear complexity, employing online clustering on key tokens and online $\ell_2$ sampling on values. The result is a provably accurate and efficient attention decoding algorithm, termed SubGen. Not only does this algorithm ensure a sublinear memory footprint and sublinear time complexity, but we also establish a tight error bound for our approach. Empirical evaluations on long-context question-answering tasks demonstrate that SubGen significantly outperforms existing and state-of-the-art KV cache compression methods in terms of performance and efficiency.
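The value side can be sketched in the same hedged way. Below, $s$ independent single-item weighted reservoirs each keep one (key, value) pair with probability $\Vert \boldsymbol v_i \Vert_2^2 / \sum_j \Vert \boldsymbol v_j \Vert_2^2$. Importance weighting then turns the samples into an unbiased estimate of the attention numerator $\sum_i \exp(\langle \boldsymbol k_i, \boldsymbol q \rangle)\, \boldsymbol v_i$. This is an assumed reading of "online $\ell_2$ sampling on values"; the class `L2ValueSampler`, its methods, and the toy data are hypothetical. Dividing the result by a denominator estimate, such as the clustering-based one described in the TL;DR, would give the attention output.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


class L2ValueSampler:
    """Hypothetical sketch: s independent streaming samplers, each holding one
    (key, value) pair drawn with probability ||v_i||^2 / sum_j ||v_j||^2."""

    def __init__(self, s, seed=0):
        self.rng = random.Random(seed)
        self.slots = [None] * s  # one sampled (key, value) pair per sampler
        self.total = 0.0         # running sum of ||v_i||^2

    def insert(self, key, value):
        w = sum(x * x for x in value)
        if w == 0.0:
            return  # zero vectors contribute nothing to the numerator
        self.total += w
        for j in range(len(self.slots)):
            # Replace w.p. w / total; after the stream ends, slot j holds
            # item i with probability w_i / (sum of all weights).
            if self.rng.random() < w / self.total:
                self.slots[j] = (key, value)

    def numerator(self, query):
        """Unbiased estimate of sum_i exp(<k_i, q>) v_i (needs >= 1 nonzero insert)."""
        dim = len(self.slots[0][1])
        acc = [0.0] * dim
        for key, value in self.slots:
            # Importance weight 1 / p_i = total / ||v_i||^2.
            scale = math.exp(dot(key, query)) * self.total / sum(x * x for x in value)
            for c in range(dim):
                acc[c] += scale * value[c] / len(self.slots)
        return acc


# Usage on hypothetical data: 300 tokens with 4-dimensional keys and values.
rng = random.Random(2)
keys = [[rng.gauss(0.0, 0.3) for _ in range(4)] for _ in range(300)]
values = [[rng.gauss(1.0, 0.1) for _ in range(4)] for _ in range(300)]
q = [0.2, -0.1, 0.3, 0.0]
sampler = L2ValueSampler(s=1000, seed=3)
for k, v in zip(keys, values):
    sampler.insert(k, v)
exact = [sum(math.exp(dot(k, q)) * v[c] for k, v in zip(keys, values)) for c in range(4)]
approx = sampler.numerator(q)
print([round(a / e, 3) for a, e in zip(approx, exact)])
```

Memory here is $O(sd)$ independent of the stream length, while each insert costs $O(s)$ random draws in this naive form.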
