Layer-Condensed KV Cache for Efficient Inference of Large Language Models

Haoyi Wu, Kewei Tu

TL;DR

The paper tackles the memory bottleneck of KV caches in large language model inference by introducing the Layer-Condensed KV Cache, which pairs the queries of most layers with the KVs of the top layer only and optionally keeps a sandwich of warmup layers to preserve performance. It provides a parallel training framework that relies on gradient stopping and rapid KV convergence, enabling scalable training and efficient inference. Empirical results show up to 26$\times$ higher throughput and competitive performance in language modeling and downstream tasks, and the method integrates readily with other memory-saving techniques such as StreamingLLM. The work offers a practical, orthogonal approach to improving LLM inference efficiency, with a clear trade-off between throughput and accuracy.

Abstract

Huge memory consumption has been a major bottleneck for deploying high-throughput large language models in real-world applications. In addition to the large number of parameters, the key-value (KV) cache for the attention mechanism in the transformer architecture consumes a significant amount of memory, especially when the number of layers is large for deep language models. In this paper, we propose a novel method that only computes and caches the KVs of a small number of layers, thus significantly saving memory consumption and improving inference throughput. Our experiments on large language models show that our method achieves up to 26$\times$ higher throughput than standard transformers and competitive performance in language modeling and downstream tasks. In addition, our method is orthogonal to existing transformer memory-saving techniques, so it is straightforward to integrate them with our model, achieving further improvement in inference efficiency. Our code is available at https://github.com/whyNLP/LCKV.
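
To make the memory argument concrete, the following is a back-of-the-envelope sketch of how KV cache size scales with the number of layers whose KVs are cached. The function name `kv_cache_bytes`, the Llama-7B-like shape (32 layers, 32 heads, head dimension 128, 16-bit values), and the choice of 3 cached layers are assumptions made for illustration; they are not figures reported in the paper.

```python
def kv_cache_bytes(cached_layers, n_kv_heads, head_dim, seq_len,
                   batch_size, bytes_per_value=2):
    """Estimate KV cache size in bytes (hypothetical helper).

    The leading factor 2 accounts for storing both keys and values.
    """
    return (2 * cached_layers * n_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_value)


# Assumed Llama-7B-like shape, one sequence of 2048 tokens, fp16 values.
standard = kv_cache_bytes(cached_layers=32, n_kv_heads=32, head_dim=128,
                          seq_len=2048, batch_size=1)

# Illustrative condensed setting: only 3 layers keep their KVs (assumption).
condensed = kv_cache_bytes(cached_layers=3, n_kv_heads=32, head_dim=128,
                           seq_len=2048, batch_size=1)

print(f"standard cache:  {standard / 2**30:.2f} GiB")
print(f"condensed cache: {condensed / 2**30:.3f} GiB")
print(f"reduction:       {standard / condensed:.1f}x")
```

The cache size is linear in the number of cached layers, so the freed memory can be spent on larger batches. The reduction factor printed here is a memory ratio only and should not be read as the throughput gain reported in the abstract.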

Paper Structure (27 sections, 3 theorems, 13 figures, 10 tables)

This paper contains 27 sections, 3 theorems, 13 figures, 10 tables.

Key Result

Theorem 1

The two computation graphs are equivalent in terms of model training.
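
The equivalence can be illustrated on a toy model. The sketch below is a simplified assumption-laden reconstruction, not the authors' implementation: it uses single-head attention with residual connections only, every layer's queries attend to the top-layer KVs of strictly earlier tokens, and a token with no predecessors receives a zero attention output. All names (`masked_attention`, `forward_sequential`, `forward_parallel`) are hypothetical. It compares token-by-token computation with a parallel computation repeated for $n$ iterations, as described for Figure 2.

```python
import numpy as np


def masked_attention(q, k, v, mask):
    """Softmax attention; rows with no allowed key return zeros."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -1e30)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True)) * mask
    denom = weights.sum(axis=-1, keepdims=True)
    return (weights / np.where(denom == 0, 1.0, denom)) @ v


def run_layers(h, k, v, wq, mask):
    """Apply every layer against the same top-layer KVs.

    Returns the final hidden states and the input to the top layer,
    from which the top-layer KVs are computed.
    """
    for i, w in enumerate(wq):
        if i == len(wq) - 1:
            top_in = h
        h = h + masked_attention(h @ w, k, v, mask)
    return h, top_in


def forward_sequential(x, wq, wk, wv):
    """Process tokens one at a time, appending top-layer KVs as we go."""
    n, d = x.shape
    k, v = np.zeros((n, d)), np.zeros((n, d))
    out = np.zeros((n, d))
    for t in range(n):
        mask = (np.arange(n) < t)[None, :]  # only earlier tokens are visible
        h, top_in = run_layers(x[t:t + 1], k, v, wq, mask)
        k[t], v[t] = (top_in @ wk)[0], (top_in @ wv)[0]
        out[t] = h[0]
    return out


def forward_parallel(x, wq, wk, wv, iterations):
    """Process all tokens at once, refreshing the KVs each iteration."""
    n, d = x.shape
    k, v = np.zeros((n, d)), np.zeros((n, d))
    mask = np.tril(np.ones((n, n), dtype=bool), k=-1)  # strictly causal
    for _ in range(iterations):
        h, top_in = run_layers(x, k, v, wq, mask)
        k, v = top_in @ wk, top_in @ wv
    return h


rng = np.random.default_rng(0)
n, d, n_layers = 6, 8, 2
x = rng.normal(size=(n, d))
wq = [rng.normal(scale=0.3, size=(d, d)) for _ in range(n_layers)]
wk = rng.normal(scale=0.3, size=(d, d))
wv = rng.normal(scale=0.3, size=(d, d))

seq = forward_sequential(x, wq, wk, wv)
par = forward_parallel(x, wq, wk, wv, iterations=n)
print(np.allclose(seq, par))
```

In this toy setting, each parallel iteration makes at least one more token exact: after iteration $i$ the first $i$ tokens match the sequential result, so $n$ iterations suffice. The sketch covers only the forward pass; gradient stopping and the choice of how many iterations to run in practice are not modeled here.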

Figures (13)

  • Figure 1: Illustration of a standard transformer decoder and our model. Each node represents one layer of transformer computation of one token. Each horizontal edge $a \rightarrow b$ denotes that the queries at $b$ are paired with the KVs at $a$.
  • Figure 2: Two equivalent computation graphs for training our model with two layers. (a) Sequential over $n$ tokens. (b) Parallel over $n$ tokens with $n$ iterations. Sub-graphs with the same color (except grey) represent identical data flow. Sub-graphs in grey are unused in loss computation.
  • Figure 3: MSE of the KV before and after the $i$th iteration. The model is randomly initialized and tested with 2048 tokens.
  • Figure 4: Throughput of 7B Llama and our model w.r.t. the batch size.
  • Figure 5: Comparison of latency per token and memory consumption of StreamingLLM and our model ($w=10$) integrated with StreamingLLM w.r.t. different cache sizes.
  • ...and 8 more figures

Theorems & Definitions (4)

  • Theorem 1
  • Theorem 1
  • Lemma 1
  • proof