QJL: 1-Bit Quantized JL Transform for KV Cache Quantization with Zero Overhead
Amir Zandieh, Majid Daliri, Insu Han
TL;DR
The paper tackles the KV cache memory bottleneck in autoregressive LLM inference by proposing QJL, a data-oblivious quantization method with zero memory overhead that applies a Johnson-Lindenstrauss (JL) transform followed by 1-bit sign quantization. It introduces an asymmetric inner-product estimator, $\mathrm{Prod}_{\mathrm{QJL}}$, which is unbiased and comes with rigorous distortion guarantees that bound its effect on attention scores. The key theoretical result is that keys can be stored using $m = O(\varepsilon^{-2}\log n)$ bits per token, independent of the embedding dimension. In practice, quantizing the KV cache to 3 bits per floating-point number yields a more than 5x memory reduction while maintaining accuracy. The approach is GPU-friendly, includes outlier handling via dual quantizers as well as an orthogonalized JL transform, and shows strong end-to-end results on LongBench with Llama-2 family models, suggesting substantial practical impact for long-context generation.
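To see why an estimator built from 1-bit keys and unquantized queries can be unbiased, the derivation below assumes $\mathrm{Prod}_{\mathrm{QJL}}$ takes the standard sign-sketch form: the key is reduced to $\mathrm{sign}(Sk)$ and rescaled by its $\ell_2$ norm, while the query is only projected as $Sq$. The symbol $Q_{\mathrm{qjl}}$ is shorthand introduced here, and the step marked $=0$ uses the fact that $s_i^\top q_\perp$ and $s_i^\top \hat{k}$ are uncorrelated jointly Gaussian variables, hence independent. This covers only the expectation, not the distortion bound behind $m = O(\varepsilon^{-2}\log n)$.

```latex
\begin{aligned}
&S \in \mathbb{R}^{m \times d},\quad s_1, \dots, s_m \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, I_d) \text{ (rows of } S\text{)}, \\
&Q_{\mathrm{qjl}}(k) = \mathrm{sign}(Sk) \in \{-1, +1\}^m, \\
&\mathrm{Prod}_{\mathrm{QJL}}(q, k) = \frac{\sqrt{\pi/2}}{m}\,\|k\|_2\,\bigl\langle Sq,\, Q_{\mathrm{qjl}}(k) \bigr\rangle
  = \frac{\sqrt{\pi/2}}{m}\,\|k\|_2 \sum_{i=1}^{m} (s_i^\top q)\,\mathrm{sign}(s_i^\top k), \\[4pt]
&q = \alpha \hat{k} + q_\perp,\quad \hat{k} = \frac{k}{\|k\|_2},\quad \alpha = \langle q, \hat{k} \rangle,\quad \langle q_\perp, k \rangle = 0, \\
&\mathbb{E}\bigl[(s_i^\top q)\,\mathrm{sign}(s_i^\top k)\bigr]
  = \alpha\,\mathbb{E}\bigl[g\,\mathrm{sign}(g)\bigr] + \underbrace{\mathbb{E}[s_i^\top q_\perp]\,\mathbb{E}\bigl[\mathrm{sign}(g)\bigr]}_{=0}
  = \alpha\,\mathbb{E}|g| = \sqrt{2/\pi}\,\alpha, \qquad g = s_i^\top \hat{k} \sim \mathcal{N}(0,1), \\
&\mathbb{E}\bigl[\mathrm{Prod}_{\mathrm{QJL}}(q, k)\bigr]
  = \frac{\sqrt{\pi/2}}{m}\,\|k\|_2 \cdot m \sqrt{2/\pi}\,\frac{\langle q, k \rangle}{\|k\|_2}
  = \langle q, k \rangle.
\end{aligned}
```

For the memory figures, assuming 16-bit floating-point storage: a conventional quantizer that keeps a 16-bit zero point and a 16-bit scale per block of $B$ numbers pays $32/B$ extra bits per number (1 bit for $B = 32$, 2 bits for $B = 16$), whereas storing 3 bits per number against a 16-bit baseline gives $16/3 \approx 5.33\times$ compression, consistent with the reported more than 5x reduction.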
Abstract
Serving LLMs requires substantial memory due to the storage requirements of Key-Value (KV) embeddings in the KV cache, which grows with sequence length. An effective approach to compress KV cache is quantization. However, traditional quantization methods face significant memory overhead due to the need to store quantization constants (at least a zero point and a scale) in full precision per data block. Depending on the block size, this overhead can add 1 or 2 bits per quantized number. We introduce QJL, a new quantization approach that consists of a Johnson-Lindenstrauss (JL) transform followed by sign-bit quantization. In contrast to existing methods, QJL eliminates memory overheads by removing the need for storing quantization constants. We propose an asymmetric estimator for the inner product of two vectors and demonstrate that applying QJL to one vector and a standard JL transform without quantization to the other provides an unbiased estimator with minimal distortion. We have developed an efficient implementation of the QJL sketch and its corresponding inner product estimator, incorporating a lightweight CUDA kernel for optimized computation. When applied across various LLMs and NLP tasks to quantize the KV cache to only 3 bits, QJL demonstrates a more than fivefold reduction in KV cache memory usage without compromising accuracy, all while achieving faster runtime. Code is available at https://github.com/amirzandieh/QJL.
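The asymmetric scheme described in the abstract (quantize one vector with QJL, only project the other) can be sketched in plain Python. This is a minimal illustration, not the authors' CUDA implementation: it assumes a dense i.i.d. Gaussian sketch matrix $S$ (the orthogonalized JL variant and outlier handling are omitted), assumes the estimator rescales by the key's $\ell_2$ norm with the factor $\sqrt{\pi/2}/m$, and the helper names `gaussian_matrix`, `qjl_quantize_key`, and `qjl_inner_product` are hypothetical.

```python
import math
import random


def gaussian_matrix(m, d, rng):
    """Draw an m x d matrix with i.i.d. N(0, 1) entries (the JL sketch S)."""
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]


def matvec(S, x):
    """Dense matrix-vector product S @ x using lists."""
    return [sum(a * b for a, b in zip(row, x)) for row in S]


def qjl_quantize_key(S, k):
    """Keep only the sign bits of S @ k (as +/-1) plus the key's Euclidean norm.

    No per-block zero point or scale is stored. sign(0) is mapped to +1,
    which happens with probability zero for a Gaussian sketch.
    """
    bits = [1 if v >= 0 else -1 for v in matvec(S, k)]
    norm = math.sqrt(sum(v * v for v in k))
    return bits, norm


def qjl_inner_product(S, q, bits, norm):
    """Asymmetric estimate of <q, k>: the query is projected but NOT quantized.

    Assumed form: sqrt(pi/2) / m * ||k||_2 * <S q, sign(S k)>.
    """
    m = len(S)
    sq = matvec(S, q)
    return math.sqrt(math.pi / 2) / m * norm * sum(a * b for a, b in zip(sq, bits))


# Usage: compare the estimate against the exact inner product.
rng = random.Random(0)
d, m = 16, 4096
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
k = [rng.gauss(0.0, 1.0) for _ in range(d)]
S = gaussian_matrix(m, d, rng)

bits, norm = qjl_quantize_key(S, k)
exact = sum(a * b for a, b in zip(q, k))
est = qjl_inner_product(S, q, bits, norm)
print(f"exact={exact:.4f}  qjl_estimate={est:.4f}")
```

Each cached key costs $m$ sign bits plus one scalar norm, independent of the number of channels being grouped into blocks. With a Gaussian sketch, the estimator's variance is at most $\frac{\pi}{2m}\|q\|_2^2\|k\|_2^2$, so increasing `m` trades memory for accuracy at a rate of roughly $1/m$.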
