Layer-Condensed KV Cache for Efficient Inference of Large Language Models
Haoyi Wu, Kewei Tu
TL;DR
The paper tackles the memory bottleneck of the KV cache in large language model inference by introducing Layer-Condensed KV Cache (LCKV). LCKV pairs the queries of most layers with the keys and values (KVs) of the top layer only, and it can optionally keep a "sandwich" of standard warmup layers at the bottom and top to preserve performance. Because each token then depends on the top-layer KVs of earlier tokens, training cannot be parallelized over the sequence directly. The paper therefore computes the KVs over several parallel iterations, relying on their rapid convergence and stopping gradients through the earlier iterations, which keeps training scalable while inference stays efficient. Empirical results show up to 26× higher throughput with competitive language modeling and downstream task performance, and the method integrates straightforwardly with other memory-saving techniques such as StreamingLLM. The work offers a practical approach to boosting LLM inference efficiency that is orthogonal to existing techniques, with a clear trade-off between throughput and accuracy.
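To make the mechanism concrete, here is a toy sketch in plain Python. It is not the authors' implementation. It assumes scalar hidden states, a single attention head, hypothetical weights (`W_Q`, `W_K`, `W_V`), no warmup layers, and that a token attends only to the top-layer KVs of earlier tokens, so the first token's attention output is zero. `sequential_inference` shows why the cache holds one KV pair per token instead of one per layer. `parallel_iterations` mimics the parallel training idea with Jacobi-style iterations that start from all-zero KVs. Gradient stopping is omitted because there is no autograd here.

```python
import math
from typing import List, Tuple

KV = Tuple[float, float]  # (key, value) from one token's top layer

# Hypothetical toy weights: one query weight per layer, shared key/value projections.
N_LAYERS = 4
W_Q = [0.9, -0.5, 0.7, 0.3]
W_K, W_V = 0.8, 1.2


def forward_token(x: float, past_kv: List[KV]) -> KV:
    """Run one token through all layers. Every layer's query attends to the
    top-layer KVs of earlier tokens; return this token's own top-layer (k, v)."""
    h = x
    for layer in range(N_LAYERS):
        q = W_Q[layer] * h
        if past_kv:
            scores = [q * k for k, _ in past_kv]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            z = sum(weights)
            attn = sum(w * v for w, (_, v) in zip(weights, past_kv)) / z
        else:
            attn = 0.0  # first token: nothing earlier to attend to (toy assumption)
        h = math.tanh(h + attn)
    return (W_K * h, W_V * h)


def sequential_inference(xs: List[float]) -> List[KV]:
    """Decode token by token; the cache holds ONE (k, v) per token, not one per layer."""
    cache: List[KV] = []
    for x in xs:
        cache.append(forward_token(x, cache))
    return cache


def parallel_iterations(xs: List[float], n_iters: int) -> List[KV]:
    """Recompute every token 'in parallel' from the previous iteration's KVs,
    starting from all-zero KVs (Jacobi-style fixed-point iteration)."""
    kv: List[KV] = [(0.0, 0.0)] * len(xs)
    for _ in range(n_iters):
        kv = [forward_token(x, kv[:t]) for t, x in enumerate(xs)]
    return kv


if __name__ == "__main__":
    xs = [0.5, -1.0, 0.3, 0.8, -0.2, 0.1]
    exact = sequential_inference(xs)
    for n_iters in (1, 2, 4, len(xs)):
        approx = parallel_iterations(xs, n_iters)
        err = max(abs(a[1] - b[1]) for a, b in zip(approx, exact))
        print(f"{n_iters} iterations: max |v - v_exact| = {err:.2e}")
```

After `k` iterations the first `k` tokens are exact, so `len(xs)` iterations always reproduce sequential decoding. The paper's observation is that in practice the KVs converge in far fewer iterations than the sequence length, which is what makes parallel training affordable.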
Abstract
Huge memory consumption has been a major bottleneck for deploying high-throughput large language models in real-world applications. In addition to the large number of parameters, the key-value (KV) cache for the attention mechanism in the transformer architecture consumes a significant amount of memory, especially for deep language models with many layers. In this paper, we propose a novel method that computes and caches the KVs of only a small number of layers, thus significantly reducing memory consumption and improving inference throughput. Our experiments on large language models show that our method achieves up to 26× higher throughput than standard transformers and competitive performance in language modeling and downstream tasks. In addition, our method is orthogonal to existing transformer memory-saving techniques, so it is straightforward to integrate them with our model, achieving further improvement in inference efficiency. Our code is available at https://github.com/whyNLP/LCKV.
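The memory saving follows from simple arithmetic: the KV cache stores one key vector and one value vector per cached layer, head, and token. The sketch below plugs in a hypothetical 7B-scale configuration (32 layers, 32 KV heads of dimension 128, fp16, 4096 tokens). All numbers and function names are illustrative assumptions, and the batch-size estimate ignores weights, activations, and fragmentation, so it is not a throughput prediction. The layer count for the warmup variant is also an assumption. `streaming_tokens` models the StreamingLLM idea of keeping a few initial attention-sink tokens plus a recent window, to show how a token-level saving multiplies with a layer-level one.

```python
GiB = 1024 ** 3
MiB = 1024 ** 2


def kv_cache_bytes(cached_layers: int, n_kv_heads: int, head_dim: int,
                   cached_tokens: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes for keys AND values (factor 2) of every cached layer, head, and token."""
    return 2 * cached_layers * n_kv_heads * head_dim * cached_tokens * batch * bytes_per_elem


def streaming_tokens(seq_len: int, n_sink: int, window: int) -> int:
    """StreamingLLM-style cache: a few initial 'sink' tokens plus a recent window."""
    return min(seq_len, n_sink + window)


def max_batch(budget_bytes: int, per_sequence_bytes: int) -> int:
    """How many sequences fit into a fixed KV-cache memory budget."""
    return budget_bytes // per_sequence_bytes


if __name__ == "__main__":
    # Hypothetical 7B-scale configuration: 32 layers, 32 heads of size 128, fp16.
    LAYERS, HEADS, HEAD_DIM, SEQ = 32, 32, 128, 4096
    rows = [
        ("standard", kv_cache_bytes(LAYERS, HEADS, HEAD_DIM, SEQ)),   # every layer cached
        ("condensed", kv_cache_bytes(1, HEADS, HEAD_DIM, SEQ)),       # top layer only
        ("+warmup", kv_cache_bytes(3, HEADS, HEAD_DIM, SEQ)),         # top + 2 warmup layers
        ("+streaming", kv_cache_bytes(1, HEADS, HEAD_DIM,
                                      streaming_tokens(SEQ, 4, 1020))),  # 1024 tokens kept
    ]
    for name, b in rows:
        print(f"{name:>10}: {b / MiB:8.1f} MiB per sequence, "
              f"{max_batch(16 * GiB, b)} sequences in a 16 GiB budget")
```

Caching only the top layer cuts per-sequence KV memory by a factor equal to the number of layers (32× here), and the freed memory can hold a larger batch, which is where the throughput gain comes from. The 26× figure in the abstract is the paper's measured result, not something this arithmetic reproduces.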
