Efficient Pretraining Length Scaling
Bohong Wu, Shen Yan, Sijun Zhang, Jianqiao Lu, Yutao Zeng, Ya Wang, Xun Zhou
TL;DR
The paper addresses the challenge of achieving effective length scaling during pre-training without inflating the KV cache or incurring large inference overhead. It introduces the Parallel Hidden Decoding (PHD) Transformer, which repeats each input token $K$ times but retains KV caches only for the original tokens, discarding the KV caches of the hidden decoding tokens after use. This scales the effective sequence length while keeping the same KV cache footprint as a vanilla transformer. Two variants refine the design: PHD-SWA applies sliding window attention to preserve local dependencies, and PHD-CSWA applies chunk-wise sliding window attention to eliminate the linear growth in pre-filling time. Empirical results on multiple benchmarks show consistent accuracy gains with manageable overhead, demonstrating that pre-training length scaling can be both effective and resource-efficient. The work situates itself among sparse attention, KV cache optimization, and latent thinking transformers, offering a practical way to scale pre-training length without linearly increasing memory requirements or decoding latency.
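To make the KV cache bookkeeping concrete, here is a minimal Python sketch of one plausible attention mask for the PHD scheme. It assumes (our assumption, not the paper's specification) that each token occupies $K$ consecutive positions: copy 0 is the original token, whose KV entry is retained, and copies 1 to $K-1$ are hidden decoding tokens that see all earlier original tokens plus the earlier copies of their own token. The function names `phd_attention_mask` and `retained_kv_positions` are hypothetical.

```python
from typing import List


def phd_attention_mask(n: int, k: int) -> List[List[bool]]:
    """Hypothetical PHD-style mask for n tokens, each repeated k times.

    Assumed layout: position p = i * k + c holds copy c of token i.
    Copy 0 is the original token (KV retained); copies 1..k-1 are hidden
    decoding tokens (KV discarded after use). mask[q][p] is True when
    query position q may attend to key position p.
    """
    size = n * k
    mask = [[False] * size for _ in range(size)]
    for q in range(size):
        qi, qc = divmod(q, k)
        for p in range(size):
            pi, pc = divmod(p, k)
            if pc == 0 and pi <= qi:
                # Retained original-token KV: causal over all earlier originals.
                mask[q][p] = True
            elif pi == qi and pc <= qc:
                # Hidden copies of the same token, causal within the group.
                mask[q][p] = True
    return mask


def retained_kv_positions(n: int, k: int) -> List[int]:
    """Positions whose KV entries stay in the cache: only the original copies."""
    return [i * k for i in range(n)]
```

For `n = 3` and `k = 2`, only positions 0, 2 and 4 stay in the cache, i.e. 3 entries, the same count as a vanilla transformer, whereas caching every repeated position would need 6. The hidden copy of token 1 (position 3) sees the originals at positions 0 and 2 and itself, but not the hidden copy of token 0.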
Abstract
Recent advances in large language models have demonstrated the effectiveness of length scaling during post-training, yet its potential in pre-training remains underexplored. We present the Parallel Hidden Decoding Transformer (\textit{PHD}-Transformer), a novel framework that enables efficient length scaling during pre-training while maintaining inference efficiency. \textit{PHD}-Transformer achieves this through an innovative KV cache management strategy that distinguishes between original tokens and hidden decoding tokens. By retaining only the KV cache of original tokens for long-range dependencies while immediately discarding hidden decoding tokens after use, our approach maintains the same KV cache size as the vanilla transformer while enabling effective length scaling. To further enhance performance, we introduce two optimized variants: \textit{PHD-SWA} employs sliding window attention to preserve local dependencies, while \textit{PHD-CSWA} implements chunk-wise sliding window attention to eliminate linear growth in pre-filling time. Extensive experiments demonstrate consistent improvements across multiple benchmarks.
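The two variants differ from the base scheme only in which hidden decoding tokens remain visible. The sketch below is a hypothetical illustration, not the authors' implementation: it assumes the same layout as a token followed by its hidden copies, that PHD-SWA lets hidden copies also attend to the hidden copies of the `window` most recent tokens, and that PHD-CSWA additionally cuts that window at chunk boundaries of size `chunk`. The function name and parameters are our own, and the code only models visibility; it does not model how pre-filling is scheduled.

```python
from typing import List, Optional


def phd_window_mask(n: int, k: int, window: int,
                    chunk: Optional[int] = None) -> List[List[bool]]:
    """Hypothetical PHD-SWA / PHD-CSWA style mask (assumed semantics).

    Layout: position p = i * k + c; copy 0 is the original token.
    Original copies attend causally to all earlier originals only.
    Hidden copies also attend to hidden copies of the `window` most recent
    tokens (their own token included, causally within it). If `chunk` is
    set, that window never crosses a chunk boundary (chunk-wise window).
    """
    size = n * k
    mask = [[False] * size for _ in range(size)]
    for q in range(size):
        qi, qc = divmod(q, k)
        for p in range(size):
            pi, pc = divmod(p, k)
            if pc == 0:
                # Retained original-token KV: plain causal attention.
                mask[q][p] = pi <= qi
            elif qc > 0:
                # Hidden query looking at a hidden key.
                in_window = qi - window < pi <= qi
                same_chunk = chunk is None or qi // chunk == pi // chunk
                causal = pi < qi or pc <= qc
                mask[q][p] = in_window and same_chunk and causal
    return mask
```

Setting `window=1` reduces to the case where hidden tokens see only the hidden copies of their own token. With `n=4, k=2, window=2`, the hidden copy of token 2 (position 5) can see the hidden copy of token 1 (position 3), but not once `chunk=2` places tokens 1 and 2 in different chunks.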
