A Little Goes a Long Way: Efficient Long Context Training and Inference with Partial Contexts

Suyu Ge, Xihui Lin, Yunan Zhang, Jiawei Han, Hao Peng

TL;DR

This paper tackles the high cost of extending LLM context lengths by integrating length extension with a GPU-friendly KV-cache reduction architecture. It introduces LongGen, a hybrid transformer that keeps the middle $1/3$ of its layers as full attention and applies KV-reduced sparse attention in the remaining layers at the bottom and top, paired with two static sparsity patterns (Attention Sink and Block Sparse) and custom kernels to cut training FLOPs and KV-cache usage. Trained on a lightweight $5$B-token dataset to extend the context of Llama-2 base models from $4K$ to $128K$, LongGen delivers a $1.55\times$ training speedup and a $62\%$ KV-cache memory reduction, plus $1.67\times$ prefilling and $1.41\times$ decoding speedups, while maintaining strong long-context performance on retrieval and reasoning benchmarks. The results suggest that combining a short length-extension training stage with an architecture built on GPU-friendly sparsity can yield practical, scalable long-context LLMs, with promising implications for larger models and real-time serving.

Abstract

Training and serving long-context large language models (LLMs) incur substantial overhead. To address this, two critical steps are often required: a pretrained LLM typically undergoes a separate stage for context length extension by training on long-context data, followed by architectural modifications to reduce the overhead of KV cache during serving. This paper argues that integrating length extension with a GPU-friendly KV cache reduction architecture not only reduces training overhead during length extension, but also achieves better long-context performance. This leads to our proposed LongGen, which finetunes a pretrained LLM into an efficient architecture during length extension. LongGen builds on three key insights: (1) Sparse attention patterns, such as window attention (attending to recent tokens), attention sink (initial ones), and blockwise sparse attention (strided token blocks) are well-suited for building efficient long-context models, primarily due to their GPU-friendly memory access patterns, enabling efficiency gains not just theoretically but in practice as well. (2) It is essential for the model to have direct access to all tokens. A hybrid architecture with 1/3 full attention layers and 2/3 efficient ones achieves a balanced trade-off between efficiency and long-context performance. (3) Lightweight training on 5B tokens of long-context data is sufficient to extend the hybrid model's context length from 4K to 128K. We evaluate LongGen on both Llama-2 7B and Llama-2 70B, demonstrating its effectiveness across different scales. During training with 128K-long contexts, LongGen achieves a 1.55x training speedup and reduces wall-clock time by 36%, compared to a full-attention baseline. During inference, LongGen reduces KV cache memory by 62%, achieving a 1.67x prefilling speedup and a 1.41x decoding speedup.
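The sparse patterns named in insight (1) can be illustrated with small boolean attention masks. The sketch below is only an illustration under stated assumptions, not the paper's kernels: the function names, the parameters `num_sink`, `window`, `block_size`, and `stride`, and the exact rule for which earlier blocks stay visible are all hypothetical choices made for this example.

```python
def sink_window_mask(seq_len, num_sink, window):
    """Causal mask combining attention sink and window attention.

    Query i may attend to key j when j <= i and j is either one of the
    first `num_sink` tokens or one of the `window` most recent tokens
    (the query position itself counts as recent).
    """
    return [
        [j <= i and (j < num_sink or i - j < window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


def block_sparse_mask(seq_len, block_size, stride):
    """Causal blockwise sparse mask (assumed rule, for illustration only).

    Tokens are grouped into blocks of `block_size`. Query i may attend to
    its own block and to every block whose index is a multiple of `stride`.
    """
    mask = []
    for i in range(seq_len):
        query_block = i // block_size
        row = []
        for j in range(seq_len):
            key_block = j // block_size
            visible = j <= i and (key_block == query_block or key_block % stride == 0)
            row.append(visible)
        mask.append(row)
    return mask


def density(mask):
    """Fraction of the causal (lower-triangular) positions that are kept."""
    n = len(mask)
    kept = sum(sum(row) for row in mask)
    return kept / (n * (n + 1) // 2)


if __name__ == "__main__":
    for name, m in [
        ("sink+window", sink_window_mask(64, num_sink=4, window=8)),
        ("block sparse", block_sparse_mask(64, block_size=8, stride=4)),
    ]:
        print(name, round(density(m), 3))
```

Both masks keep whole contiguous runs of keys per query, which is the property the abstract credits for GPU-friendly memory access; a real implementation would skip masked blocks inside the kernel rather than materialize a dense mask.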

Paper Structure (28 sections, 3 figures, 6 tables)

Figures (3)

  • Figure 1: Overview of LongGen. Left: It uses a hybrid architecture, and applies KV-reduced attention in 2/3 layers at the top and bottom, while keeping the middle 1/3 layers full attention. Right: Two KV-reduced attention variants are explored.
  • Figure 2: Training and inference efficiency under different sparsity levels. Left: Training wall-clock speedup. Middle: KV memory reduction. Right: Inference speedup. We compare training wall-clock time with FlashAttention and benchmark inference on vLLM. All results are measured on Llama-2 7B, which consists of 32 layers in total. "1/7", "1/5", "1/3 Full" and "All Full" indicate using 5, 7, 12, and 32 full layers, respectively.
  • Figure 3: Inference time KV cache reduction methods fail on long context.
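The layer layout in Figure 1 and the "1/3 Full" setting in Figure 2 (12 full layers out of 32) can be sanity-checked with a few lines of arithmetic. The sketch below is a rough illustration, not the authors' code. It assumes the KV-reduced layers are split evenly between bottom and top, and that each one keeps a fixed fraction `sparse_keep` of the tokens; both assumptions and all function names are hypothetical.

```python
def hybrid_layout(num_layers, num_full):
    """Put `num_full` full-attention layers in the middle of the stack.

    The remaining KV-reduced layers are split between bottom and top
    (an even split is an assumption made for this sketch).
    """
    num_sparse = num_layers - num_full
    bottom = num_sparse // 2
    top = num_sparse - bottom
    return ["sparse"] * bottom + ["full"] * num_full + ["sparse"] * top


def kv_cache_reduction(layout, sparse_keep):
    """Fraction of KV cache saved relative to an all-full-attention model.

    `sparse_keep` is the assumed fraction of tokens whose keys and values
    each KV-reduced layer still stores (0.0 means none).
    """
    kept = sum(1.0 if kind == "full" else sparse_keep for kind in layout)
    return 1.0 - kept / len(layout)


def wallclock_reduction(speedup):
    """Convert a speedup factor into a fractional wall-clock time reduction."""
    return 1.0 - 1.0 / speedup


if __name__ == "__main__":
    layout = hybrid_layout(num_layers=32, num_full=12)
    print(layout.count("full"), "full layers,", layout.count("sparse"), "KV-reduced layers")
    print("KV reduction if sparse layers stored nothing:", kv_cache_reduction(layout, 0.0))
    print("Wall-clock reduction at 1.55x speedup:", round(wallclock_reduction(1.55), 3))
```

Under these assumptions, the ceiling on KV savings is 1 - 12/32 = 62.5%, reached only if the KV-reduced layers stored nothing. The reported 62% sits just below that ceiling, as expected when those layers retain a small cache. Likewise, a 1.55x speedup corresponds to roughly a 35.5% wall-clock reduction, in line with the reported 36%.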