QJL: 1-Bit Quantized JL Transform for KV Cache Quantization with Zero Overhead

Amir Zandieh, Majid Daliri, Insu Han

TL;DR

The paper tackles the KV cache memory bottleneck in autoregressive LLM inference by proposing QJL, a data-oblivious, zero-overhead quantization method that uses a Johnson-Lindenstrauss transform followed by 1-bit sign quantization. It introduces an asymmetric inner-product estimator Prod_QJL that remains unbiased, with rigorous distortion guarantees ensuring minimal impact on attention scores. The key result is that keys can be stored using $m = O(\varepsilon^{-2}\log n)$ bits per token, independent of embedding dimension, enabling a practical >5x memory reduction when quantizing KV caches to 3 bits per floating-point number while maintaining accuracy. The approach is GPU-friendly, includes outlier handling via dual quantizers and an orthogonalized JL transform, and demonstrates strong end-to-end performance gains on LongBench and Llama-2 family models, suggesting substantial practical impact for long-context generation.
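A minimal NumPy sketch of the key encoding described above. It assumes a Gaussian JL matrix ${\bm S} \in \mathbb{R}^{m \times d}$ with i.i.d. N(0,1) entries and keys stored as rows. Only the signs of ${\bm S}{\bm k}$ are kept, packed eight per byte, so each token costs m bits of sign data whatever the embedding dimension d. We also keep each key's L2 norm, since (as we understand the paper's estimator) it rescales by $\|{\bm k}\|_2$. The float16 storage for that norm, the choice m = 256, and the function names qjl_encode and qjl_decode_signs are illustrative assumptions, not the authors' code.

```python
import numpy as np

def qjl_encode(keys: np.ndarray, S: np.ndarray):
    """1-bit JL sketch of each key row: keep sign(S k) as packed bits plus ||k||_2."""
    projected = keys @ S.T                          # (n, m): JL projection S k per token
    bits = projected >= 0                           # sign bit, treating 0 as +1
    packed = np.packbits(bits, axis=1)              # m bits -> ceil(m / 8) bytes per token
    norms = np.linalg.norm(keys, axis=1).astype(np.float16)  # assumed fp16 norm per token
    return packed, norms

def qjl_decode_signs(packed: np.ndarray, m: int) -> np.ndarray:
    """Unpack stored bits into an (n, m) matrix of +1.0 / -1.0 signs."""
    bits = np.unpackbits(packed, axis=1, count=m)
    return np.where(bits == 1, 1.0, -1.0)

rng = np.random.default_rng(0)
m = 256                                             # sketch size (illustrative)
for d in (64, 128):
    S = rng.standard_normal((m, d))                 # Gaussian JL matrix, i.i.d. N(0, 1)
    keys = rng.standard_normal((1000, d))           # synthetic keys, one row per token
    packed, norms = qjl_encode(keys, S)
    print(d, packed.shape[1])                       # 32 bytes of signs per token for both d
```

Changing d from 64 to 128 leaves the sign storage at 32 bytes per token; only the cost of computing ${\bm S}{\bm k}$ grows with d.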

Abstract

Serving LLMs requires substantial memory due to the storage requirements of Key-Value (KV) embeddings in the KV cache, which grows with sequence length. An effective approach to compressing the KV cache is quantization. However, traditional quantization methods face significant memory overhead due to the need to store quantization constants (at least a zero point and a scale) in full precision per data block. Depending on the block size, this overhead can add 1 or 2 bits per quantized number. We introduce QJL, a new quantization approach that consists of a Johnson-Lindenstrauss (JL) transform followed by sign-bit quantization. In contrast to existing methods, QJL eliminates memory overheads by removing the need for storing quantization constants. We propose an asymmetric estimator for the inner product of two vectors and demonstrate that applying QJL to one vector and a standard JL transform without quantization to the other provides an unbiased estimator with minimal distortion. We have developed an efficient implementation of the QJL sketch and its corresponding inner product estimator, incorporating a lightweight CUDA kernel for optimized computation. When applied across various LLMs and NLP tasks to quantize the KV cache to only 3 bits, QJL demonstrates a more than fivefold reduction in KV cache memory usage without compromising accuracy, all while achieving faster runtime. Code is available at https://github.com/amirzandieh/QJL.
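The 1 to 2 bits of overhead mentioned in the abstract follow from amortizing the per-block constants over the block. A small sketch, assuming one zero point and one scale per block, each stored as a 16-bit float (the abstract does not specify the precision or the block sizes used); the function name bits_per_number is hypothetical:

```python
def bits_per_number(value_bits: int, block_size: int,
                    const_bits: int = 16, num_constants: int = 2) -> float:
    """Effective storage per quantized number when every block of `block_size`
    numbers also carries `num_constants` constants (zero point, scale) of
    `const_bits` bits each."""
    overhead = num_constants * const_bits / block_size   # constants amortized over the block
    return value_bits + overhead

for block_size in (16, 32, 64):
    print(block_size, bits_per_number(2, block_size))
# prints:
# 16 4.0
# 32 3.0
# 64 2.5
```

With 16-bit constants, blocks of 32 numbers add 1 bit per number and blocks of 16 add 2 bits, matching the range the abstract quotes; QJL avoids this term because it stores no per-block quantization constants.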

Paper Structure (11 sections, 3 theorems, 17 equations, 3 figures, 2 tables, 1 algorithm)

Key Result

Lemma 3.2

For any vectors ${\bm q}, {\bm k} \in \mathbb{R}^d$, the estimator $\operatorname{{\tt Prod_{QJL}}}({\bm q}, {\bm k})$ defined in Definition 3.1 satisfies $\mathbb{E}_{\bm S}\left[\operatorname{{\tt Prod_{QJL}}}({\bm q}, {\bm k})\right] = \langle {\bm q}, {\bm k} \rangle$, where the expectation is over the randomness of the JL matrix ${\bm S}$ in Definition 3.1.
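Lemma 3.2 can be checked numerically. The sketch below assumes the estimator form from the paper's Definition 3.1, which this page does not reproduce: with ${\bm S} \in \mathbb{R}^{m \times d}$ having i.i.d. N(0,1) entries, $\operatorname{{\tt Prod_{QJL}}}({\bm q}, {\bm k}) = \frac{\sqrt{\pi/2}}{m} \|{\bm k}\|_2 \langle {\bm S}{\bm q}, \mathrm{sign}({\bm S}{\bm k}) \rangle$. Unbiasedness then rests on the Gaussian identity $\mathbb{E}[\langle {\bm s}, {\bm q} \rangle \, \mathrm{sign}(\langle {\bm s}, {\bm k} \rangle)] = \sqrt{2/\pi}\, \langle {\bm q}, {\bm k} \rangle / \|{\bm k}\|_2$ for a row ${\bm s} \sim \mathcal{N}(0, I_d)$. Averaging the estimate over many independent draws of ${\bm S}$ should approach the exact inner product. The function name prod_qjl and the sizes d, m, trials are illustrative.

```python
import numpy as np

def prod_qjl(q: np.ndarray, k: np.ndarray, S: np.ndarray) -> float:
    """Asymmetric estimate of <q, k>: the query gets a plain JL projection S q,
    while the key contributes only its sign bits sign(S k) and its norm."""
    m = S.shape[0]
    key_signs = np.where(S @ k >= 0, 1.0, -1.0)      # what the key cache would store
    return float(np.sqrt(np.pi / 2) / m * np.linalg.norm(k) * np.dot(S @ q, key_signs))

rng = np.random.default_rng(0)
d, m, trials = 64, 128, 2000
q = rng.standard_normal(d)
k = rng.standard_normal(d)

# Approximate the expectation over S by averaging over independent Gaussian draws.
estimates = np.array([prod_qjl(q, k, rng.standard_normal((m, d))) for _ in range(trials)])
print(f"exact <q,k> = {q @ k:.3f}, mean estimate = {estimates.mean():.3f}")
```

Any single estimate is noisy; it is the average over draws of ${\bm S}$ that converges to $\langle {\bm q}, {\bm k} \rangle$, which is exactly what unbiasedness asserts.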

Figures (3)

  • Figure 1: Overview of the KV cache quantization via Quantized JL (QJL) transform
  • Figure 2: The magnitude of key cache entries for different layers of the Llama-2 model, based on an example prompt, reveals notable patterns. The coordinates of embeddings (channels) are sorted by their average magnitude over tokens. In the initial layers, no significant outlier patterns are observed. However, in the deeper layers, a few channels (approximately four) exhibit visibly larger magnitudes, indicating the presence of significant outliers. This observation highlights the importance of addressing these outliers to improve quantization accuracy and reduce distortion in the key cache.
  • Figure 3: Wall-clock time (ms) to encode a prompt and quantize the KV cache (left), generate 128 tokens for the Llama-2 model (middle), and generate 64 tokens for the Llama-3 model (right) using different quantization methods in a model with a single attention layer. The input sequence length varies from 1k to 64k. Both KIVI and QJL (ours) with 3 bits per floating-point number (FPN) show faster decoding time than the baseline. However, KVQuant is significantly slower during both the quantization and decoding phases. QJL is the only method that can quantize Llama-3, as our kernels support grouped query attention and the BF16 data type. For Llama-3 generation, we observe the same speed as the exact method. Note that QJL uses at least 5-fold less memory than the exact method and supports all data types.
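Figure 2 sorts channels by their average magnitude over tokens and finds about four outlier channels in deeper layers, and the TL;DR mentions handling outliers with dual quantizers. A minimal sketch of the first step, assuming outliers are simply the top channels by mean absolute value: it returns outlier and inlier channel indices, which a dual-quantizer scheme could then quantize separately. The fixed count of four, the synthetic data, and the name split_outlier_channels are our assumptions; this page does not describe how the paper selects outlier channels.

```python
import numpy as np

def split_outlier_channels(keys: np.ndarray, num_outliers: int = 4):
    """Split key channels into outliers and inliers by average magnitude.

    keys: (n_tokens, d) key cache for one layer/head.
    Returns (outlier_idx, inlier_idx); outliers are the `num_outliers`
    channels with the largest mean |entry| over tokens.
    """
    avg_mag = np.abs(keys).mean(axis=0)          # per-channel average magnitude over tokens
    order = np.argsort(avg_mag)[::-1]            # channel indices, largest magnitude first
    return np.sort(order[:num_outliers]), np.sort(order[num_outliers:])

# Synthetic example: 4 of 128 channels are scaled up to mimic deep-layer outliers.
rng = np.random.default_rng(0)
keys = rng.standard_normal((512, 128))
keys[:, [3, 40, 77, 120]] *= 20.0
outliers, inliers = split_outlier_channels(keys)
print(outliers.tolist())  # [3, 40, 77, 120]
```

On real key caches, the number of outlier channels would have to be chosen from data such as Figure 2 rather than fixed in advance.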

Theorems & Definitions (7)

  • Definition 3.1: QJL and inner product estimator
  • Lemma 3.2: Inner product estimator $\operatorname{{\tt Prod_{QJL}}}$ is unbiased
  • proof
  • Lemma 3.5: Distortion of inner product estimator $\operatorname{{\tt Prod_{QJL}}}$
  • proof
  • Theorem 3.6: Distortion bound on QJL key cache quantizer
  • proof
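Lemma 3.5 and Theorem 3.6 appear here by title only. The TL;DR states their practical consequence: $m = O(\varepsilon^{-2}\log n)$ sign bits per key suffice for distortion $\varepsilon$, so for a fixed number of keys n the achievable $\varepsilon$ shrinks roughly like $\sqrt{\log n / m}$. The sketch below estimates $\langle {\bm q}, {\bm k}_i \rangle$ for n keys at once and reports the worst normalized error as m grows. It assumes the estimator form $\frac{\sqrt{\pi/2}}{m} \|{\bm k}\|_2 \langle {\bm S}{\bm q}, \mathrm{sign}({\bm S}{\bm k}) \rangle$ with a Gaussian ${\bm S}$; the normalization by $\|{\bm q}\|_2 \|{\bm k}_i\|_2$, the sizes, and the name prod_qjl_batch are illustrative choices, and the theorems' exact constants are not reproduced.

```python
import numpy as np

def prod_qjl_batch(q: np.ndarray, K: np.ndarray, S: np.ndarray) -> np.ndarray:
    """Estimate <q, k_i> for every key row k_i from sign(S k_i) and ||k_i||_2 only."""
    m = S.shape[0]
    key_signs = np.where(K @ S.T >= 0, 1.0, -1.0)    # (n, m) stored sign bits
    key_norms = np.linalg.norm(K, axis=1)             # (n,) stored key norms
    return np.sqrt(np.pi / 2) / m * key_norms * (key_signs @ (S @ q))

rng = np.random.default_rng(0)
d, n = 128, 1024
q = rng.standard_normal(d)
K = rng.standard_normal((n, d))
exact = K @ q
scale = np.linalg.norm(q) * np.linalg.norm(K, axis=1)  # ||q||_2 * ||k_i||_2 per key

for m in (64, 256, 1024):
    S = rng.standard_normal((m, d))
    err = np.abs(prod_qjl_batch(q, K, S) - exact) / scale
    # Worst normalized error over all n keys at this sketch size m.
    print(m, round(float(err.max()), 4))
```

Quadrupling m should roughly halve the printed worst-case error, consistent with the $\varepsilon^{-2}$ dependence on the sketch size.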