A Little Goes a Long Way: Efficient Long Context Training and Inference with Partial Contexts
Suyu Ge, Xihui Lin, Yunan Zhang, Jiawei Han, Hao Peng
TL;DR
This paper tackles the high cost of extending LLM context lengths by integrating length extension with a GPU-friendly KV-cache reduction architecture. It introduces LongGen, a hybrid transformer that keeps full attention in the middle $1/3$ of its layers and applies KV-reduced sparse attention in the layers at both ends of the stack. These sparse layers use two static strategies (Attention Sink and Block Sparse) together with custom kernels, which cut training FLOPs and KV-cache usage. LongGen is trained on a lightweight dataset of $5$B tokens to extend the context window of Llama-2 base models from $4$K to $128$K tokens. It delivers a $1.55\times$ training speedup, a $62\%$ KV-cache memory reduction, and $1.67\times$ prefilling and $1.41\times$ decoding speedups, while maintaining strong long-context performance on retrieval and reasoning benchmarks. The results show that lightweight, architecture-aware training combined with GPU-friendly sparsity can yield practical, scalable long-context LLMs, with promising implications for larger models and real-time serving.
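The following is a minimal sketch that makes the hybrid layout and its KV-cache effect concrete. It assumes a contiguous block of full-attention layers centred in the stack and assumes that every sparse layer caches at most a fixed number of tokens (for example sink + window). It also assumes that all layers have the same KV width, so cache size is proportional to cached tokens. The helper names `assign_layer_types` and `kv_cache_fraction` and the 4096-token budget are hypothetical illustrations, not the paper's implementation or configuration.

```python
def assign_layer_types(num_layers: int, full_fraction: float = 1 / 3) -> list[str]:
    """Hypothetical layout: a contiguous middle block of layers uses full attention.

    round(num_layers * full_fraction) layers are marked 'full'; the rest are 'sparse'.
    """
    num_full = round(num_layers * full_fraction)
    start = (num_layers - num_full) // 2  # index of the first full-attention layer
    return ["full" if start <= i < start + num_full else "sparse" for i in range(num_layers)]


def kv_cache_fraction(layer_types: list[str], seq_len: int, sparse_budget: int) -> float:
    """Fraction of an all-full-attention KV cache that the hybrid model keeps.

    Assumes equal KV width in every layer, so cache size is proportional to the
    number of cached tokens per layer. Full layers cache all seq_len tokens;
    sparse layers cache at most sparse_budget tokens (e.g. sink + window).
    """
    cached = sum(
        seq_len if t == "full" else min(sparse_budget, seq_len) for t in layer_types
    )
    return cached / (len(layer_types) * seq_len)


if __name__ == "__main__":
    types = assign_layer_types(32)  # 32 decoder layers, as in Llama-2 7B
    print(types.count("full"), types.index("full"))  # prints: 11 10
    # Hypothetical budget of 4096 cached tokens per sparse layer at a 128K context.
    frac = kv_cache_fraction(types, seq_len=128 * 1024, sparse_budget=4096)
    print(f"KV cache kept: {frac:.1%}, reduction: {1 - frac:.1%}")
```

With these illustrative numbers the saving approaches, but can never exceed, the share of sparse layers (21/32 here). This is why the fraction of full-attention layers is the main knob trading memory against direct access to all tokens.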
Abstract
Training and serving long-context large language models (LLMs) incurs substantial overhead. To address this, two separate steps are typically required: a pretrained LLM first undergoes a context-length extension stage in which it is trained on long-context data, and is then architecturally modified to reduce KV-cache overhead during serving. This paper argues that integrating length extension with a GPU-friendly KV-cache reduction architecture not only reduces training overhead during length extension but also achieves better long-context performance. This leads to our proposed LongGen, which finetunes a pretrained LLM into an efficient architecture during length extension. LongGen builds on three key insights: (1) Sparse attention patterns are well suited to building efficient long-context models. Examples include window attention (attending to recent tokens), attention sink (attending to the initial tokens), and blockwise sparse attention (attending to strided token blocks). Their main advantage is GPU-friendly memory access, which yields efficiency gains in practice, not just in theory. (2) It is essential for the model to have direct access to all tokens. A hybrid architecture with 1/3 full-attention layers and 2/3 efficient ones achieves a balanced trade-off between efficiency and long-context performance. (3) Lightweight training on 5B tokens of long-context data is sufficient to extend the hybrid model's context length from 4K to 128K. We evaluate LongGen on both Llama-2 7B and Llama-2 70B, demonstrating its effectiveness across different scales. During training with 128K-long contexts, LongGen achieves a 1.55x training speedup and reduces wall-clock time by 36% compared to a full-attention baseline. During inference, LongGen reduces KV-cache memory by 62%, achieving a 1.67x prefilling speedup and a 1.41x decoding speedup.
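The following sketch builds boolean causal masks for the sparse patterns named above. It shows which keys each query may attend to: window and sink are combined in one mask, and block-sparse attention gets a second. The function names, the block size, and the rule that keeps the query's most recent blocks plus every `stride`-th earlier block are hypothetical choices for illustration, not the paper's exact configuration. Dense masks are used only for readability; an efficient kernel would load just the allowed blocks rather than materialise a full mask.

```python
def sink_window_mask(seq_len: int, sink: int, window: int) -> list[list[bool]]:
    """Causal mask: query q sees the first `sink` tokens plus its last `window` tokens."""
    return [
        [k <= q and (k < sink or q - k < window) for k in range(seq_len)]
        for q in range(seq_len)
    ]


def block_sparse_mask(seq_len: int, block: int, local_blocks: int, stride: int) -> list[list[bool]]:
    """Causal block mask (hypothetical pattern): a query sees its `local_blocks` most
    recent blocks, including its own, plus every `stride`-th earlier block."""
    mask = []
    for q in range(seq_len):
        qb = q // block  # block index of the query
        mask.append([
            k <= q and (qb - k // block < local_blocks or (k // block) % stride == 0)
            for k in range(seq_len)
        ])
    return mask


def show(mask: list[list[bool]]) -> None:
    """Print one row per query: '#' marks an attended key, '.' a skipped one."""
    for row in mask:
        print("".join("#" if allowed else "." for allowed in row))


if __name__ == "__main__":
    show(sink_window_mask(8, sink=2, window=3))
    print()
    show(block_sparse_mask(8, block=2, local_blocks=1, stride=2))
```

In both masks the allowed keys form a few contiguous runs (an initial prefix, a recent window, or whole blocks), which is the kind of access pattern a GPU kernel can fetch tile by tile. For the sink + window mask, each query attends to at most `sink + window` keys regardless of sequence length, so the per-layer KV cache stays bounded.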
