HeadInfer: Memory-Efficient LLM Inference by Head-wise Offloading

Cheng Luo, Zefan Cai, Hanshi Sun, Jinqi Xiao, Bo Yuan, Wen Xiao, Junjie Hu, Jiawei Zhao, Beidi Chen, Anima Anandkumar

TL;DR

HeadInfer tackles the memory bottleneck of KV caches in long-context LLM inference by offloading KV state at the attention-head level. It uses adaptive head grouping, chunked prefill, and asynchronous CPU-GPU transfers (ping-pong memory) to keep only a small number of heads on the GPU at a time, preserving exact computation while drastically reducing on-GPU memory. Roofline analysis shows HeadInfer maintains compute efficiency in prefill and remains memory-bound during decoding, enabling millions of tokens of context on consumer GPUs. Empirically, it extends feasible context lengths to 1M–4M tokens across models (e.g., Llama-3-8B, Llama-3-70B) and reduces GPU memory from hundreds of gigabytes to tens of gigabytes, significantly democratizing access to long-context inference.
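To make the ping-pong idea concrete, here is a minimal CPU-only sketch. It assumes one query vector per head (a decoding step) and processes one head at a time rather than HeadInfer's adaptive head groups. A background thread stands in for an asynchronous CPU-to-GPU copy, and a plain list copy stands in for the transfer itself. `fetch_head`, `attend` and `headwise_decode_step` are hypothetical names, not HeadInfer's code. While head h is being computed, head h+1 is already being fetched into the other buffer, so at most two heads' KV data are "resident" at once, and the attention output is exact.

```python
import math
from concurrent.futures import ThreadPoolExecutor

def fetch_head(cpu_kv, h):
    # Hypothetical stand-in for an async host-to-device copy of one head's K and V.
    keys, values = cpu_kv[h]
    return [list(k) for k in keys], [list(v) for v in values]

def attend(q, keys, values):
    # Exact softmax attention of a single query over one head's keys and values.
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]

def headwise_decode_step(queries, cpu_kv):
    """Compute head h while head h+1 is prefetched (two-buffer 'ping-pong')."""
    num_heads = len(cpu_kv)
    outputs = [None] * num_heads
    with ThreadPoolExecutor(max_workers=1) as copier:
        pending = copier.submit(fetch_head, cpu_kv, 0)  # fill the first buffer
        for h in range(num_heads):
            keys, values = pending.result()  # wait until head h has arrived
            if h + 1 < num_heads:
                # Start filling the other buffer before computing on this one.
                pending = copier.submit(fetch_head, cpu_kv, h + 1)
            outputs[h] = attend(queries[h], keys, values)
    return outputs
```

In a real GPU implementation, the thread would be replaced by a separate copy stream and pinned host memory. The design point is the same: transfer latency for the next head overlaps with compute on the current one.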

Abstract

Transformer-based large language models (LLMs) demonstrate impressive performance in long-context generation. Extending the context length has disproportionately shifted the memory footprint of LLMs during inference to the key-value cache (KV cache). In this paper, we propose HEADINFER, which offloads the KV cache to CPU RAM while avoiding the need to fully store the KV cache for any transformer layer on the GPU. HEADINFER employs a fine-grained, head-wise offloading strategy, maintaining only the KV cache of selected attention heads on the GPU while computing attention output dynamically. Through roofline analysis, we demonstrate that HEADINFER maintains computational efficiency while significantly reducing memory footprint. We evaluate HEADINFER on the Llama-3-8B model with a 1-million-token sequence, reducing the GPU memory footprint of the KV cache from 128 GB to 1 GB and the total GPU memory usage from 207 GB to 17 GB, a 92% reduction compared to BF16 baseline inference. Notably, HEADINFER enables 4-million-token inference with an 8B model on a single consumer GPU with 24 GB of memory (e.g., NVIDIA RTX 4090) without approximation methods.
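The KV cache figures in the abstract can be reproduced with simple arithmetic. This is a sketch under stated assumptions, and the paper's own accounting may differ. It uses the public Llama-3-8B shape: 32 layers, 8 key-value heads (grouped-query attention), and head dimension 128. It also assumes BF16 storage at 2 bytes per element, and it treats "1 million" tokens as 2**20 and "GB" as GiB. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(tokens, layers, kv_heads, head_dim, bytes_per_elem=2):
    # Leading factor of 2: one tensor for keys and one for values.
    return 2 * tokens * layers * kv_heads * head_dim * bytes_per_elem

GiB = 2**30
TOKENS = 2**20  # assumption: "1 million tokens" taken as 2**20

# Assumed Llama-3-8B shape: 32 layers, 8 KV heads, head_dim 128, BF16.
full = kv_cache_bytes(TOKENS, layers=32, kv_heads=8, head_dim=128)      # whole model
one_layer = kv_cache_bytes(TOKENS, layers=1, kv_heads=8, head_dim=128)  # layer-wise unit
one_head = kv_cache_bytes(TOKENS, layers=1, kv_heads=1, head_dim=128)   # head-wise unit

print(full / GiB, one_layer / GiB, one_head / GiB, 2 * one_head / GiB)
```

Under these assumptions, the full cache is 128 GiB. A single layer (the unit of layer-wise offloading) is 4 GiB, and a single KV head of one layer is 0.5 GiB. Two head buffers for ping-pong transfers therefore come to 1 GiB, which lines up with the 1 GB on-GPU KV footprint reported above.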

Paper Structure

This paper contains 32 sections, 13 equations, 11 figures, 14 tables, 1 algorithm.

Figures (11)

  • Figure 1: Estimated memory consumption for inference of a Llama-3-8B model with 1 million tokens on a single GPU.
  • Figure 2: Demonstrations of KV cache policies in inference. The full KV cache has two main dimensions: layer and head. Layer-wise offloading offloads the KV cache along the layer dimension, with a cache budget of all heads per layer. HeadInfer further reduces GPU memory by adaptively reallocating cache budgets along the head dimension, with a cache budget of one head.
  • Figure 3: Granularity of different methods. Each cube represents the entire attention process along three dimensions: Sequence (S), Layers (L), and Heads (H). Standard inference puts everything on the GPU. Chunked prefill keeps only one chunk of the sequence dimension (a subset of the tokens) on the GPU at a time. Layer-wise offloading keeps a subset of layers on the GPU and offloads the rest. HeadInfer takes an even finer-grained approach, keeping only selected heads within a layer on the GPU.
  • Figure 4: HeadInfer snapshot. All parameters are stored on the GPU. The head-wise partitioned KV cache is moved between the GPU and CPU using ping-pong memory.
  • Figure 5: Roofline analysis of FlashAttention using the RTX 4090 device setting.
  • ...and 6 more figures
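The roofline claims (compute-efficient prefill, memory-bound decoding) follow from the arithmetic intensity of attention, measured in FLOPs per byte moved. The sketch below is a rough per-head model and is not the paper's exact analysis. It ignores softmax FLOPs and assumes each of Q, K, V and the output is read or written once, as in a fused FlashAttention-style kernel. The chunk size of 4096 and the ridge point of 100 FLOPs/byte are hypothetical values chosen for illustration. A real ridge point is peak FLOP/s divided by memory bandwidth for a specific device.

```python
def attention_intensity(q_tokens, kv_tokens, head_dim, bytes_per_elem=2):
    # QK^T and PV each cost 2 * q_tokens * kv_tokens * head_dim FLOPs.
    flops = 4 * q_tokens * kv_tokens * head_dim
    # Bytes: K and V read once, Q read once, output written once.
    bytes_moved = bytes_per_elem * head_dim * (2 * kv_tokens + 2 * q_tokens)
    return flops / bytes_moved

def bound(intensity, ridge):
    # Below the ridge point, bandwidth limits throughput; above it, compute does.
    return "compute-bound" if intensity >= ridge else "memory-bound"

S = 2**20        # context length (assumed 1M = 2**20 tokens)
RIDGE = 100.0    # hypothetical ridge point in FLOPs/byte

decode = attention_intensity(1, S, 128)      # one new token attends to S tokens
prefill = attention_intensity(4096, S, 128)  # hypothetical 4096-token prefill chunk
print(round(decode, 3), round(prefill), bound(decode, RIDGE), bound(prefill, RIDGE))
```

In this model, decoding sits near 1 FLOP/byte and prefill grows roughly with the chunk size. Neither figure depends on how many heads are processed together. This fits the claim that splitting work head by head need not hurt efficiency, though the paper's own roofline analysis also accounts for factors this sketch omits.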