Pretraining with Token-Level Adaptive Latent Chain-of-Thought
Boyi Zeng, Yiqin Hao, He Li, Shixiang Song, Feichen Song, Zitong Wang, Siyuan Huang, Yi Xu, ZiWei He, Xinbing Wang, Zhouhan Lin
TL;DR
This work tackles data- and compute-constrained scaling of large language models by increasing per-token compute rather than model size. It introduces Adaptive Latent CoT, a one-stage pretraining framework that unrolls a variable number of latent steps before emitting each token. The framework combines Parallel Masking, a Router for probabilistic halting, and a correctness-aware adaptive loss that encourages early halting on easy tokens. Experiments with Llama backbones show lower perplexity and stronger downstream performance at comparable or reduced training FLOPs, with efficiency gains in both training and inference. By allocating latent reasoning according to token difficulty, the approach matches computation to the effort each token requires, offering a practical path to more capable models without increasing parameter count.
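The TL;DR names a Router for probabilistic halting but does not specify its form. As a rough illustration only, the sketch below borrows the PonderNet-style formulation, a swapped-in technique that is not necessarily the authors' design: at each latent step a router emits a conditional halting probability, and these are combined into a distribution over trajectory lengths, including the zero-step case. The function names, the fixed step budget K, and the 0.5 greedy inference threshold are all hypothetical.

```python
from typing import List


def halting_distribution(lambdas: List[float]) -> List[float]:
    """Probability of halting after exactly k latent steps, k = 0..K.

    lambdas[k] is a hypothetical router output: the probability of halting
    after k latent steps, given that the model has not halted earlier.
    lambdas[0] covers the zero-step case (emit the token directly).
    """
    if not lambdas:
        raise ValueError("need at least one router output")
    probs: List[float] = []
    survive = 1.0  # probability of not having halted before step k
    for k, lam in enumerate(lambdas):
        if not 0.0 <= lam <= 1.0:
            raise ValueError(f"router output {lam} is not a probability")
        if k == len(lambdas) - 1:
            # Force a halt at the step budget K so the distribution sums to 1.
            probs.append(survive)
        else:
            probs.append(survive * lam)
            survive *= 1.0 - lam
    return probs


def expected_latent_steps(lambdas: List[float]) -> float:
    """Expected trajectory length under the halting distribution."""
    return sum(k * p for k, p in enumerate(halting_distribution(lambdas)))


def greedy_halt_step(lambdas: List[float], threshold: float = 0.5) -> int:
    """Assumed inference rule: halt at the first step whose router output
    reaches the threshold, or at the budget K otherwise."""
    for k, lam in enumerate(lambdas):
        if lam >= threshold:
            return k
    return len(lambdas) - 1


if __name__ == "__main__":
    easy = [0.9, 0.5, 0.5, 0.5]   # hypothetical: router is confident at once
    hard = [0.05, 0.1, 0.2, 0.5]  # hypothetical: router keeps reasoning
    print(expected_latent_steps(easy), greedy_halt_step(easy))
    print(expected_latent_steps(hard), greedy_halt_step(hard))
```

Under this toy model an "easy" token concentrates its halting mass at zero steps, while a "hard" token pushes most of it toward the budget; a loss that rewards early halting on easy tokens would act on these router outputs.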
Abstract
Scaling large language models by increasing parameters and training data is increasingly constrained by limited high-quality corpora and rising communication costs. This work explores an alternative axis: increasing per-token computation without expanding parameters, by internalizing latent Chain-of-Thought (CoT) into pretraining. We propose Pretraining with Token-Level Adaptive Latent CoT (adaptive latent CoT), in which the model generates a variable-length latent CoT trajectory before emitting each token, allocating longer trajectories to difficult tokens and shorter (or even zero-length) trajectories to easy ones. Importantly, this behavior emerges naturally from one-stage pretraining on general text, and token-wise adaptive halting reduces computation in both training and inference. Experiments with Llama architectures show that adaptive latent CoT consistently improves language modeling perplexity and accuracy across a broad range of downstream tasks, even with fewer training FLOPs than prior recurrent baselines.
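The claim that token-wise adaptive halting saves compute can be made concrete with simple arithmetic. The sketch below compares an adaptive model against a hypothetical fixed-depth recurrent baseline that runs the same number of latent steps R for every token. The cost model (one equal-cost pass per latent step plus one emission pass) is an assumption for illustration; it ignores attention-length effects, Parallel Masking, and router overhead, so it is not how the paper counts FLOPs. The function name `relative_compute` is hypothetical.

```python
from typing import Sequence


def relative_compute(latent_steps: Sequence[int], fixed_steps: int) -> float:
    """Adaptive-to-fixed compute ratio under a crude proxy cost model.

    Assumption: each token costs one emission pass plus one pass per latent
    step, and every pass costs the same. Attention cost growth, Parallel
    Masking, and router overhead are deliberately ignored.
    """
    if not latent_steps:
        raise ValueError("need at least one token")
    if fixed_steps < 0 or any(k < 0 for k in latent_steps):
        raise ValueError("step counts must be non-negative")
    adaptive_passes = sum(1 + k for k in latent_steps)
    fixed_passes = len(latent_steps) * (1 + fixed_steps)
    return adaptive_passes / fixed_passes


if __name__ == "__main__":
    # Hypothetical per-token trajectory lengths: mostly easy tokens, one hard one.
    steps = [0, 0, 1, 3]
    print(relative_compute(steps, fixed_steps=3))  # prints 0.5
```

With trajectory lengths [0, 0, 1, 3] against R = 3, the adaptive model uses 8 passes versus 16, half the proxy compute; actual savings depend on the learned halting behavior and on how FLOPs are counted.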
