
Efficient Pretraining Length Scaling

Bohong Wu, Shen Yan, Sijun Zhang, Jianqiao Lu, Yutao Zeng, Ya Wang, Xun Zhou

TL;DR

The paper addresses the challenge of achieving effective length scaling during pre-training without inflating the KV cache or adding large inference overhead. It introduces the Parallel Hidden Decoding (PHD) Transformer, which repeats each input token $K$ times, retains KV caches only for the original tokens, and discards the hidden decoding tokens after use, so the effective sequence length grows while the KV cache footprint matches that of a vanilla transformer. Two variants refine this design: PHD-SWA applies sliding window attention to hidden decoding tokens to preserve local dependencies, and PHD-CSWA applies chunk-wise sliding window attention to avoid linear growth in pre-filling time. Empirical results on multiple benchmarks show consistent accuracy gains with manageable overhead, indicating that pre-training length scaling can be both effective and resource-efficient. The work situates itself among sparse attention, KV cache optimization, and latent thinking transformers, offering a practical way to scale pre-training length without linearly increasing memory requirements or decoding latency.
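The memory claim can be checked with back-of-the-envelope arithmetic. The sketch below assumes the standard KV layout (one key and one value vector per KV head, per layer, per cached token) and 2-byte elements; the helper names and the model dimensions in the usage lines are hypothetical and are not the paper's configuration. Naively repeating every token $K$ times multiplies the cache by $K$, whereas PHD caches only the original copies.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values for num_tokens positions."""
    # Factor 2: one key vector and one value vector per KV head, per layer, per token.
    return 2 * num_layers * num_tokens * num_kv_heads * head_dim * bytes_per_elem


def compare_kv(seq_len, K, **model):
    """Return (vanilla, naive K-times repetition, PHD) KV cache sizes in bytes."""
    vanilla = kv_cache_bytes(seq_len, **model)
    # Naive length scaling: all K copies of every token stay in the cache.
    naive = kv_cache_bytes(seq_len * K, **model)
    # PHD: only original tokens are cached; hidden decoding tokens are discarded
    # after use (the K - 1 transient entries of the token being decoded are ignored).
    phd = kv_cache_bytes(seq_len, **model)
    return vanilla, naive, phd


# Hypothetical model: 16 layers, 8 KV heads of dimension 64, 2-byte (bf16) elements.
vanilla, naive, phd = compare_kv(4096, 3, num_layers=16, num_kv_heads=8, head_dim=64)
print(vanilla / 2**20, naive / 2**20, phd / 2**20)  # sizes in MiB: 128.0 384.0 128.0
```

Under these assumed dimensions, a 4,096-token context needs the same 128 MiB for the vanilla transformer and PHD, versus 384 MiB if all three copies were cached.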

Abstract

Recent advances in large language models have demonstrated the effectiveness of length scaling during post-training, yet its potential in pre-training remains underexplored. We present the Parallel Hidden Decoding Transformer (PHD-Transformer), a novel framework that enables efficient length scaling during pre-training while maintaining inference efficiency. PHD-Transformer achieves this through an innovative KV cache management strategy that distinguishes between original tokens and hidden decoding tokens. By retaining only the KV cache of original tokens for long-range dependencies while immediately discarding hidden decoding tokens after use, our approach maintains the same KV cache size as the vanilla transformer while enabling effective length scaling. To further enhance performance, we introduce two optimized variants: PHD-SWA employs sliding window attention to preserve local dependencies, while PHD-CSWA implements chunk-wise sliding window attention to eliminate linear growth in pre-filling time. Extensive experiments demonstrate consistent improvements across multiple benchmarks.
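To make the KV retention rule concrete, here is a minimal NumPy sketch of the attention mask it implies, assuming the interleaved layout described in the figure captions (each original token followed by its K - 1 hidden copies, so position p = t*K + k). The helper names `phd_layout`, `phd_attention_mask`, and `loss_positions` are hypothetical, and the mask is a reading of the stated rule rather than the authors' kernel: original tokens are visible to all later positions, hidden decoding tokens are visible only to later copies of the same token, and only the final copy feeds the loss.

```python
import numpy as np


def phd_layout(T, K):
    """Interleaved layout: position p = t*K + k is copy k of original token t.

    Copy k == 0 is the original token; copies 1..K-1 are hidden decoding tokens.
    """
    t = np.repeat(np.arange(T), K)  # source token index of each position
    k = np.tile(np.arange(K), T)    # copy index of each position
    return t, k


def phd_attention_mask(T, K):
    """mask[q, s] is True if query position q may attend to key position s."""
    t, k = phd_layout(T, K)
    tq, kq = t[:, None], k[:, None]
    ts, ks = t[None, :], k[None, :]
    # Original tokens keep their KV cache and are visible to every later position.
    to_original = (ks == 0) & (ts < tq)
    # Copies of the current token attend causally to copies 0..kq of that token;
    # hidden entries are never visible to later tokens (they are discarded).
    within_group = (ts == tq) & (ks <= kq)
    return to_original | within_group


def loss_positions(T, K):
    """Only the final copy of each token is used for next-token prediction."""
    _, k = phd_layout(T, K)
    return np.flatnonzero(k == K - 1)


m = phd_attention_mask(3, 2)
print(m.astype(int))
print(loss_positions(3, 2))
```

Because original tokens never attend to hidden keys, the rows for original positions reduce to an ordinary causal mask over original tokens, which is why the cached KV size matches the vanilla transformer.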

Paper Structure

This paper contains 34 sections, 3 equations, 9 figures, and 2 tables.

Figures (9)

  • Figure 1: The length scaling curve on a 151M-parameter model. We repeat the training sequence 1/2/3/4 times on the same model architecture and train each model for 100B tokens. The training loss and downstream accuracy scale robustly with the number of token repetitions. For repeated training sequences, we only use the final copy of each token for the next-token prediction loss.
  • Figure 2: Overview of the transformer block in PHD. The input tokens are repeated multiple times and fed into the transformer block simultaneously. The original tokens generate KV cache entries that can be attended to by all subsequent tokens, while the hidden decoding tokens generate KV cache entries that can only be attended to by copies of the current token (Token 3 in the figure). We only use the final copy of each token for the next-token prediction loss.
  • Figure 3: The attention matrix in PHD. Interleaving original tokens and hidden decoding tokens produces a very sparse attention matrix that is not hardware-friendly. We rearrange the input sequence to split the original tokens and hidden decoding tokens into two groups. This gathers the masked (un-attended) positions into one contiguous block, which is efficient to optimize.
  • Figure 4: Comparison of the attention matrices in PHD, PHD-SWA, and PHD-CSWA. In this figure, the repetition count $K$ is 3, so each original token has 2 hidden decoding tokens; the window size $W$ is 4 and the chunk size $C$ is 4.
  • Figure 5: Training curves of PHD-CSWA variants and the baseline model on OLMo2-1.2B. We smooth these metrics with an exponential moving average, using weight 0.99 for loss and 0.7 for downstream tasks.
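The PHD variants differ only in which extra hidden-token keys a query may see. The sketch below is one plausible reading of the Figure 4 description and the abstract, with loudly stated assumptions: the window W is counted in original-token positions, only hidden decoding tokens receive the extra local attention (so the original tokens' KV cache is unchanged), and in the chunk-wise variant the window is additionally clipped to chunks of C original positions. The functions `layout`, `hidden_window_mask`, and `grouped_order` are hypothetical names, not the paper's code. The sketch also builds the grouped ordering described for Figure 3.

```python
import numpy as np


def layout(T, K):
    """Position p = t*K + k holds copy k of original token t (k == 0 is the original)."""
    return np.repeat(np.arange(T), K), np.tile(np.arange(K), T)


def hidden_window_mask(T, K, W=None, C=None):
    """Boolean mask[q, s]: may query q attend to key s?

    W=None  -> plain PHD
    W given -> PHD-SWA-style local window (assumed semantics)
    W and C -> PHD-CSWA-style chunk-wise window (assumed semantics)
    """
    t, k = layout(T, K)
    tq, kq, ts, ks = t[:, None], k[:, None], t[None, :], k[None, :]
    # Plain PHD: earlier original tokens, plus causal attention over the current token's copies.
    mask = ((ks == 0) & (ts < tq)) | ((ts == tq) & (ks <= kq))
    if W is not None:
        # Assumption: hidden queries also see hidden keys of the previous W - 1 original positions.
        local = (kq > 0) & (ks > 0) & (ts < tq) & (tq - ts < W)
        if C is not None:
            # Assumption: the local window never crosses a boundary between chunks of C positions.
            local &= (ts // C) == (tq // C)
        mask = mask | local
    return mask


def grouped_order(T, K):
    """Figure 3-style permutation: all original tokens first, then all hidden decoding tokens."""
    _, k = layout(T, K)
    return np.concatenate([np.flatnonzero(k == 0), np.flatnonzero(k > 0)])


# Same parameter values as Figure 4: K = 3, W = 4, C = 4 (T = 8 is arbitrary).
T, K = 8, 3
phd = hidden_window_mask(T, K)
swa = hidden_window_mask(T, K, W=4)
cswa = hidden_window_mask(T, K, W=4, C=4)
perm = grouped_order(T, K)
phd_grouped = phd[np.ix_(perm, perm)]  # rows and columns reordered as in Figure 3
```

In the reordered plain-PHD mask, the block where original-token queries meet hidden-token keys is entirely masked, which is the contiguous un-attended region the Figure 3 caption refers to. Under these assumptions the masks are nested (PHD within PHD-CSWA within PHD-SWA); the exact patterns in the paper's figure may differ from this sketch.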