
Keyformer: KV Cache Reduction through Key Tokens Selection for Efficient Generative Inference

Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant J. Nair, Ilya Soloveychik, Purushotham Kamath

TL;DR

Keyformer tackles the KV-cache bottleneck in long-context LLM inference by identifying and retaining a small set of key tokens, complemented by a window of recent tokens. It introduces a Gumbel-based regularization of the attention logits and a temperature-scheduled score function, accumulated across decoding steps, to select key tokens robustly without retraining. Empirical results across GPT-J, Cerebras-GPT, and MPT show a 2.1x latency reduction and a 2.4x throughput gain at 50% KV-cache reduction while maintaining MLPerf-aligned accuracy, with particularly strong results on long-context summarization. The work demonstrates that inference-time KV-cache reduction via probabilistic key-token selection yields substantial practical benefits with minimal loss in quality, and it includes thorough ablations and comparisons against prior KV-cache and sparse-attention methods.
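The Gumbel-regularized, temperature-scheduled score described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names (`gumbel_noise`, `temperature`, `keyformer_scores`, `accumulate`) are hypothetical, the linear temperature ramp from `tau_init = 1.0` to `tau_end = 2.0` is an assumed schedule, and the perturbed scores are assumed to be used only for ranking tokens, not for computing the attention output. For simplicity the set of cached tokens is held fixed; in practice the accumulated scores would have to be compacted whenever tokens are evicted.

```python
import math
import random
from typing import List, Sequence


def gumbel_noise(rng: random.Random) -> float:
    """Draw one standard Gumbel(0, 1) sample via the inverse CDF -log(-log(U))."""
    u = rng.random()
    # Clamp U away from 0 and 1 so both logarithms stay finite.
    u = min(max(u, 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))


def temperature(step: int, total_steps: int,
                tau_init: float = 1.0, tau_end: float = 2.0) -> float:
    """Assumed schedule: tau rises linearly from tau_init to tau_end over decoding."""
    if total_steps <= 0:
        return tau_init
    frac = min(max(step, 0), total_steps) / total_steps
    return tau_init + (tau_end - tau_init) * frac


def keyformer_scores(logits: Sequence[float], tau: float,
                     rng: random.Random) -> List[float]:
    """Softmax of Gumbel-perturbed attention logits divided by temperature tau."""
    z = [(x + gumbel_noise(rng)) / tau for x in logits]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def accumulate(acc: Sequence[float], scores: Sequence[float]) -> List[float]:
    """Add this step's per-token scores to the running accumulated score."""
    return [a + s for a, s in zip(acc, scores)]


if __name__ == "__main__":
    rng = random.Random(0)          # fixed seed for reproducibility
    logits = [0.5, 2.0, 0.1, 1.5]   # toy attention logits for 4 cached tokens
    steps = 8
    acc = [0.0] * len(logits)
    for t in range(steps):
        acc = accumulate(acc, keyformer_scores(logits, temperature(t, steps), rng))
    print([round(a, 3) for a in acc])
```

Running the script prints one accumulated score per cached token; the tokens with the largest totals are the candidates for retention as key tokens.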

Abstract

Transformers have emerged as the underpinning architecture for Large Language Models (LLMs). In generative language models, inference involves two primary phases: prompt processing and token generation. Token generation, which constitutes the majority of the computational workload, primarily entails vector-matrix multiplications and interactions with the Key-Value (KV) cache. This phase is constrained by memory bandwidth because weights and KV cache values must be transferred from the memory system to the compute units. The bottleneck becomes particularly pronounced in applications that require long contexts and extensive text generation, both of which are increasingly important for LLMs. This paper introduces "Keyformer", an inference-time approach that mitigates the challenges of KV cache size and memory bandwidth utilization. Keyformer leverages the observation that approximately 90% of the attention weight in generative inference is concentrated on a specific subset of tokens, referred to as "key" tokens. Keyformer identifies these crucial tokens with a novel score function and retains only the key tokens in the KV cache, reducing both KV cache size and memory bandwidth usage without compromising model accuracy. We evaluate Keyformer on three foundation models, GPT-J, Cerebras-GPT, and MPT, which use different positional embedding algorithms. Our assessment covers a variety of tasks, with particular emphasis on summarization and conversation tasks involving extended contexts. By reducing the KV cache, Keyformer lowers inference latency by 2.1x and improves token generation throughput by 2.4x while preserving the model's accuracy.
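Retaining key tokens in the KV cache amounts to a selection step at each eviction. The sketch below is a hedged illustration that mirrors the recent-window-plus-key-token layout of Figure 2(d): it keeps the `w` most recent positions and the `k` older positions with the highest accumulated score, then gathers the cache entries at those positions. The function names and the list-of-entries cache representation are hypothetical; a real implementation would gather per-layer, per-head key/value tensors instead.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select_kept_indices(acc_scores: Sequence[float], w: int, k: int) -> List[int]:
    """Positions to keep: the w most recent tokens plus the k highest-scoring older ones."""
    n = len(acc_scores)
    if w + k >= n:
        return list(range(n))  # cache is already within budget; keep everything
    recent = list(range(n - w, n))
    older = range(n - w)
    # Rank older tokens by accumulated score, highest first, and take the top k.
    key_tokens = sorted(older, key=lambda i: acc_scores[i], reverse=True)[:k]
    return sorted(key_tokens + recent)  # preserve original token order


def gather(entries: Sequence[T], kept: Sequence[int]) -> List[T]:
    """Compact a per-token list (KV entries or scores) down to the kept positions."""
    return [entries[i] for i in kept]


if __name__ == "__main__":
    acc = [0.9, 0.1, 0.5, 0.05, 0.3, 0.2, 0.4, 0.6]  # accumulated scores, 8 tokens
    kv = [f"kv{i}" for i in range(len(acc))]         # stand-ins for cached K/V pairs
    kept = select_kept_indices(acc, w=2, k=2)        # 50% budget: keep 4 of 8
    print(kept)              # [0, 2, 6, 7]
    print(gather(kv, kept))  # ['kv0', 'kv2', 'kv6', 'kv7']
```

Compacting `acc` with the same `gather` call keeps the accumulated scores aligned with the surviving cache entries for the next decoding step.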

Paper Structure (56 sections, 12 equations, 16 figures, 4 tables, 1 algorithm)

Figures (16)

  • Figure 1: (a) Inference latency normalized to a sequence length of 512, with the KV cache data movement measured for the MPT-7B model at varying sequence lengths (50% context + 50% text generation). (b) KV cache size and model size as the sequence length varies. The studies are performed on an NVIDIA A100 GPU with a batch size of 1 and a beam size of 4.
  • Figure 2: Attention block for generative inference. (a) Full attention: the current token attends to all previous tokens. (b) Window attention ($w = 4$): attention is restricted to the 4 most recent tokens. (c) Dilated window attention ($w = 4$, dilation = 1). (d) Keyformer ($w = 2$, $k = 2$): a mix of a recent window ($w$) and key tokens ($k$). White indicates no attention and blue indicates attention; green marks the key tokens and their attention. The three consecutive token generation iterations shown are labeled $t-1$, $t$, and $t+1$.
  • Figure 3: (a) Default attention sparsity across different models. (b) Average attention score for three models, where 90% of the attention score is concentrated on 40% of the tokens, called key tokens. (c) Accuracy of the three models under different attention schemes. 'Full Attention' uses the full KV cache, while 'Window Attention' and 'H$_{2}$O' use 50% of the KV cache. All models are evaluated on the CNN/DailyMail dataset for the summarization task.
  • Figure 4: Reducing the KV cache changes the distribution of attention scores. As tokens are removed, their attention is redistributed unevenly among the remaining cached tokens, which affects the identification of key tokens by the score function $f_{\theta}(\text{acc attn})$. The figure shows this effect on the attention scores of the MPT-7B model with a 50% KV cache reduction.
  • Figure 5: Effect of damping on model quality for the Cerebras-GPT-6.7B model with a 50% KV cache reduction. Even with the score function damped to counteract the excess attention score, the model does not reach the quality of the full-attention model.
  • ...and 11 more figures
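The growth shown in Figure 1(b) follows from simple arithmetic: every cached token stores one key and one value vector per layer. The sketch below computes the KV cache footprint for standard multi-head attention. The function name is hypothetical, and the MPT-7B-like configuration (32 layers, hidden size 4096, fp16 storage) is an assumption used for illustration; the batch size of 1 and beam size of 4 follow the Figure 1 setup.

```python
def kv_cache_bytes(n_layers: int, d_model: int, seq_len: int,
                   batch: int = 1, beams: int = 1, bytes_per_elem: int = 2) -> int:
    """KV cache size in bytes for standard multi-head attention.

    The factor of 2 counts one key and one value vector (each d_model wide)
    per token, per layer, per sequence in the batch/beam.
    """
    return 2 * n_layers * d_model * seq_len * batch * beams * bytes_per_elem


if __name__ == "__main__":
    GIB = 2 ** 30
    # Assumed MPT-7B-like shape: 32 layers, hidden size 4096, fp16 (2 bytes).
    for seq_len in (512, 1024, 2048, 4096):
        full = kv_cache_bytes(32, 4096, seq_len, batch=1, beams=4)
        half = full // 2  # retaining 50% of the tokens halves the cache
        print(f"seq_len={seq_len}: full={full / GIB:.1f} GiB, 50% kept={half / GIB:.1f} GiB")
```

Under these assumptions the cache is 1.0 GiB at 512 tokens and grows linearly to 8.0 GiB at 4096 tokens, which is why a 50% reduction in retained tokens translates directly into a 50% reduction in KV cache memory traffic per decoding step.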