Hydragen: High-Throughput LLM Inference with Shared Prefixes
Jordan Juravsky, Bradley Brown, Ryan Ehrlich, Daniel Y. Fu, Christopher Ré, Azalia Mirhoseini
TL;DR
Hydragen is an exact, hardware-aware implementation of attention for batches of sequences that share a common prefix. It decomposes attention into shared-prefix and unique-suffix components and batches the prefix queries across sequences, which replaces many matrix-vector products with fewer matrix-matrix products and enables tensor-core utilization. Empirical results show up to 32x end-to-end throughput gains for CodeLlama-13b over competitive baselines, with speedups that grow with batch size and shared prefix length. Throughput is only modestly sensitive to prefix length (at a large batch size, growing the prefix from 1K to 16K tokens costs Hydragen less than 15% of its throughput, while baselines lose over 90%) and to suffix size. The method generalizes to tree-like sharing patterns, where it cuts inference time on competitive programming problems by a further 55%, and it is implemented in PyTorch without custom kernels, suggesting broad portability and practical impact for large-scale LLM inference environments.
Abstract
Transformer-based large language models (LLMs) are now deployed to hundreds of millions of users. LLM inference is commonly performed on batches of sequences that share a prefix, such as few-shot examples or a chatbot system prompt. Decoding in this large-batch setting can be bottlenecked by the attention operation, which reads large key-value (KV) caches from memory and computes inefficient matrix-vector products for every sequence in the batch. In this work, we introduce Hydragen, a hardware-aware exact implementation of attention with shared prefixes. Hydragen computes attention over the shared prefix and unique suffixes separately. This decomposition enables efficient prefix attention by batching queries together across sequences, reducing redundant memory reads and enabling the use of hardware-friendly matrix multiplications. Our method can improve end-to-end CodeLlama-13b throughput by up to 32x against competitive baselines, with speedup growing with the batch size and shared prefix length. Hydragen also enables the use of very long shared contexts: with a large batch size, increasing the prefix length from 1K to 16K tokens decreases Hydragen throughput by less than 15%, while the throughput of baselines drops by over 90%. Hydragen generalizes beyond simple prefix-suffix decomposition and can be applied to tree-based prompt sharing patterns, allowing us to further reduce inference time on competitive programming problems by 55%.
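To make the prefix-suffix decomposition concrete, the following is a minimal single-head sketch in NumPy. It is not the authors' implementation, which is in PyTorch. The function names (`attention_with_lse`, `combine`, `shared_prefix_attention`, `naive_attention`) and the tensor shapes are assumptions chosen for illustration. The sketch computes attention over the shared prefix with one matrix-matrix product for the whole batch, computes attention over each unique suffix separately, and merges the two partial results by re-weighting them with their softmax denominators (tracked as log-sum-exp values), which is what keeps the result exact.

```python
import numpy as np


def attention_with_lse(q, k, v):
    """Softmax attention that also returns the log-sum-exp (LSE) of the scores.

    q: (..., n_q, d), k and v: (..., n_kv, d). Illustrative helper, single head.
    """
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    m = scores.max(axis=-1, keepdims=True)        # subtract the max for stability
    exp = np.exp(scores - m)
    denom = exp.sum(axis=-1, keepdims=True)
    out = (exp / denom) @ v
    lse = m + np.log(denom)                       # shape (..., n_q, 1)
    return out, lse


def combine(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention results computed over disjoint sets of keys."""
    m = np.maximum(lse_a, lse_b)
    w_a = np.exp(lse_a - m)                       # proportional to each softmax denominator
    w_b = np.exp(lse_b - m)
    return (w_a * out_a + w_b * out_b) / (w_a + w_b)


def shared_prefix_attention(q, prefix_k, prefix_v, suffix_k, suffix_v):
    """Decoding-step attention with a shared prefix (assumed shapes).

    q:          (batch, d)            one query per sequence
    prefix_k/v: (n_prefix, d)         stored once for the whole batch
    suffix_k/v: (batch, n_suffix, d)  unique to each sequence
    """
    # Prefix: the batch of queries forms one matrix, so this is a single
    # matrix-matrix product against a prefix KV cache that is read once.
    prefix_out, prefix_lse = attention_with_lse(q, prefix_k, prefix_v)
    # Suffix: each sequence attends only to its own keys and values.
    suffix_out, suffix_lse = attention_with_lse(q[:, None, :], suffix_k, suffix_v)
    return combine(prefix_out, prefix_lse, suffix_out[:, 0, :], suffix_lse[:, 0, :])


def naive_attention(q, prefix_k, prefix_v, suffix_k, suffix_v):
    """Reference: replicate the prefix per sequence and attend over everything."""
    batch = q.shape[0]
    k = np.concatenate(
        [np.broadcast_to(prefix_k, (batch,) + prefix_k.shape), suffix_k], axis=1)
    v = np.concatenate(
        [np.broadcast_to(prefix_v, (batch,) + prefix_v.shape), suffix_v], axis=1)
    out, _ = attention_with_lse(q[:, None, :], k, v)
    return out[:, 0, :]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, n_prefix, n_suffix, d = 4, 16, 3, 8
    q = rng.standard_normal((batch, d))
    pk, pv = rng.standard_normal((2, n_prefix, d))
    sk, sv = rng.standard_normal((2, batch, n_suffix, d))
    fast = shared_prefix_attention(q, pk, pv, sk, sv)
    slow = naive_attention(q, pk, pv, sk, sv)
    assert np.allclose(fast, slow)                # the decomposition is exact
```

In this sketch the prefix keys and values are stored once with shape `(n_prefix, d)` instead of being replicated per sequence, so the prefix step reads them a single time for the whole batch. The per-sequence work that remains scales with the suffix length only.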
