
Pretraining with Token-Level Adaptive Latent Chain-of-Thought

Boyi Zeng, Yiqin Hao, He Li, Shixiang Song, Feichen Song, Zitong Wang, Siyuan Huang, Yi Xu, ZiWei He, Xinbing Wang, Zhouhan Lin

TL;DR

This work tackles data- and compute-constrained scaling of large language models by increasing per-token compute rather than model size. It introduces Adaptive Latent CoT, a one-stage pretraining framework that unrolls a variable number of latent steps before emitting each token. The framework combines Parallel Masking, which lets each latent step be computed in parallel across token positions, a Router for probabilistic halting, and a correctness-aware adaptive loss that encourages early halting on easy tokens. Empirical results with Llama backbones show lower perplexity and stronger downstream performance under comparable or reduced training FLOPs, with efficiency gains in both training and inference. By allocating latent computation according to token difficulty, the approach offers a practical path to more capable models without increasing parameter count.

Abstract

Scaling large language models by increasing parameters and training data is increasingly constrained by limited high-quality corpora and rising communication costs. This work explores an alternative axis: increasing per-token computation without expanding parameters, by internalizing latent Chain-of-Thought (CoT) into pretraining. We propose Pretraining with Token-Level Adaptive Latent CoT (adaptive latent CoT), in which the model generates a variable-length latent CoT trajectory before emitting each token, allocating longer trajectories to difficult tokens and shorter (or even zero-length) trajectories to easy ones. Importantly, this behavior emerges naturally from one-stage pretraining on general text and reduces computation in both training and inference via token-wise adaptive halting. Experiments with Llama architectures show that adaptive latent CoT consistently improves language modeling perplexity and broad downstream accuracy, even with fewer training FLOPs than prior recurrent baselines.

Paper Structure (38 sections, 11 equations, 9 figures, 2 tables)

Figures (9)

  • Figure 1: Left: Standard latent-CoT induces a strict sequential chain across both sequence length $L$ and latent depth $K$ (highlighted in red): later tokens (e.g., $x_2$) can depend on deeper latent states of earlier positions. Right: Enforcing strict 2D causality with an attention mask that allows $(t_i,k_i)\!\rightarrow\!(t_j,k_j)$ only if $t_j \le t_i$ and $k_j \le k_i$ blocks cross-token dependencies on future latent steps (shown by $\times$), making the process sequential only in $k$ and enabling parallel computation over all $t$ at each step.
  • Figure 2: Motivation for adaptive computation and the correctness-aware adaptive loss. We probe a latent-CoT model without adaptive computation on 0.34B tokens. At each latent step $k$, we bucket tokens by their current target-token probability $p_{\mathrm{target}}^{(k)}$ (x-axis), and measure the average improvement in the target probability brought by the next latent step, $p_{\mathrm{target}}^{(k+1)} - p_{\mathrm{target}}^{(k)}$ (y-axis). While extra latent computation substantially improves low-confidence tokens, the gains diminish as $p_{\mathrm{target}}^{(k)}$ increases, and can become negative for already confident tokens (negative bars), suggesting unnecessary or even harmful computation.
  • Figure 3: Inference: token-wise adaptive latent CoT via reach probability. Left: At decoding position $t$, the model iterates latent steps to produce $z_t^{(1)},z_t^{(2)},\ldots$ and updates the reach probability, stopping when the next-step reach probability $p_{\mathrm{reach},t}^{(k+1)}$ falls below the threshold $\tau$. Easy tokens ($x_2,x_3$) stop earlier (shorter latent CoT), while harder tokens ($x_1$) continue longer. The final representation is obtained by $p_{\mathrm{exit}}$-weighted mixing of the executed latent states and is fed to the LM head for prediction. Right: The Router takes each latent hidden state $z_t^{(k)}$ and outputs a gate $g_t^{(k)}$ (the conditional continuation probability), which defines the probability flow: $p_{\mathrm{reach},t}^{(k+1)}=p_{\mathrm{reach},t}^{(k)}\,g_t^{(k)}$ and $p_{\mathrm{exit},t}^{(k)}=p_{\mathrm{reach},t}^{(k)}\,(1-g_t^{(k)})$.
  • Figure 4: A training example with unrolled latent steps, pruning, and KV-cache reuse. We illustrate one sequence $(x_1,x_2,x_3)$ with a latent budget $K_{\max}=4$. With the parallel attention mask, we unroll computation along latent steps and, at each step, compute all active tokens in parallel. At step 1, the model produces $(z_1^{(1)},z_2^{(1)},z_3^{(1)})$ and the Router outputs gates (conditional continuation probabilities) for each token, which update $p_{\mathrm{reach}}$ and determine whether a token remains active. In the figure, $x_1$ and $x_2$ retain sufficient reach probability to proceed (green checks), while $x_3$ is pruned (red cross). At step 2, only $x_1$ and $x_2$ remain; the updated reach probability prunes $x_2$, so steps 3 and 4 are executed only for $x_1$. Across latent steps, later steps reuse the cached attention context (KV cache), avoiding redundant computation. Because pruning shrinks the active token set as $k$ increases, the overall training FLOPs decrease. For supervision, we compute per-position LM losses ($\mathcal{L}_1,\mathcal{L}_2,\mathcal{L}_3$) on the final representation $z_t^{\mathrm{f}}$, obtained by $p_{\mathrm{exit}}$-weighted mixing of the executed latent states, and add a correctness-aware adaptive loss (proportional to $p_{\mathrm{target}}$) that penalizes continuing when the model already assigns high probability to the ground-truth token. Together, the $p_{\mathrm{exit}}$-weighted mixing and $\mathcal{L}_{\mathrm{adaptive}}$ provide learning signals to both the LM and the Router.
  • Figure 5: Language modeling perplexity (PPL $\downarrow$). Validation perplexity on The Pile, WikiText, and LAMBADA (OpenAI and standard splits). Our method consistently achieves the lowest perplexity across all datasets while using fewer training FLOPs.
  • ...and 4 more figures
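The 2D causality rule in the Figure 1 caption (a query at $(t_i,k_i)$ may attend to $(t_j,k_j)$ only if $t_j \le t_i$ and $k_j \le k_i$) can be illustrated as a boolean mask. This is a minimal sketch, assuming the $(t,k)$ pairs are flattened in step-major order; the helper names `two_d_causal_mask` and `is_lower_triangular` are hypothetical and not taken from the paper's code.

```python
from typing import List, Tuple

Coord = Tuple[int, int]  # (token position t, latent step k)


def two_d_causal_mask(seq_len: int, k_max: int) -> Tuple[List[Coord], List[List[bool]]]:
    """Hypothetical helper: step-major coordinates plus a boolean attention mask.

    mask[q][kv] is True iff query (t_i, k_i) may attend to key (t_j, k_j),
    i.e. t_j <= t_i and k_j <= k_i, the rule stated in the Figure 1 caption.
    """
    # Step-major order: all positions at step 1, then all positions at step 2, ...
    coords = [(t, k) for k in range(1, k_max + 1) for t in range(seq_len)]
    mask = [
        [t_j <= t_i and k_j <= k_i for (t_j, k_j) in coords]
        for (t_i, k_i) in coords
    ]
    return coords, mask


def is_lower_triangular(mask: List[List[bool]]) -> bool:
    """True if no query attends to an entry that comes later in step-major order."""
    return all(
        not allowed
        for q, row in enumerate(mask)
        for kv, allowed in enumerate(row)
        if kv > q
    )


coords, mask = two_d_causal_mask(seq_len=3, k_max=2)
print(is_lower_triangular(mask))  # True: step k never reads a state deeper than k
```

Because the mask is lower triangular in step-major order, states at step $k$ depend only on steps $\le k$. The process is therefore sequential only along $k$, and all positions $t$ at a given step can be computed in one parallel pass, as the caption describes.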
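The inference procedure in the Figure 3 caption (Router gates, reach and exit probabilities, threshold $\tau$, and $p_{\mathrm{exit}}$-weighted mixing) can be sketched for a single decoding position as follows. This is an illustrative sketch only: `latent_step` and `router` are hypothetical stand-ins for the model's latent forward pass and Router, $p_{\mathrm{reach}}^{(1)}=1$ is assumed, and folding the leftover reach mass into the last executed step is an assumption, since the caption does not say how that mass is handled.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def adaptive_latent_cot_step(
    latent_step: Callable[[int], Vector],  # hypothetical: returns z_t^{(k)} for step k
    router: Callable[[Vector], float],     # hypothetical: returns gate g_t^{(k)} in [0, 1]
    tau: float,
    k_max: int,
) -> Tuple[Vector, List[float]]:
    """Run latent steps for one position t until p_reach^{(k+1)} < tau."""
    p_reach = 1.0  # assumption: every token executes at least step 1
    states: List[Vector] = []
    p_exit: List[float] = []
    for k in range(1, k_max + 1):
        z = latent_step(k)
        g = router(z)
        states.append(z)
        p_exit.append(p_reach * (1.0 - g))  # p_exit^{(k)} = p_reach^{(k)} (1 - g^{(k)})
        p_reach *= g                        # p_reach^{(k+1)} = p_reach^{(k)} g^{(k)}
        if p_reach < tau:                   # halt: the next step is unlikely to be reached
            break
    # Assumption: fold the leftover reach mass into the last executed step so the
    # mixing weights sum to 1.
    p_exit[-1] += p_reach
    # Final representation: p_exit-weighted mix of the executed latent states.
    dim = len(states[0])
    z_final = [sum(w * s[i] for w, s in zip(p_exit, states)) for i in range(dim)]
    return z_final, p_exit


# Toy usage: 1-d "states" z^{(k)} = [k] and a fixed gate per latent step.
gates = {1: 0.9, 2: 0.5, 3: 0.2, 4: 0.1}
z_final, weights = adaptive_latent_cot_step(
    latent_step=lambda k: [float(k)],
    router=lambda z: gates[int(z[0])],
    tau=0.3,
    k_max=4,
)
```

In this toy run, the position executes three latent steps with mixing weights of roughly 0.1, 0.45, and 0.45. A router that returns a gate of 0.1 at step 1 would halt after a single step, which mirrors the "easy token" case in the figure.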
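The pruning schedule in the Figure 4 caption can be reproduced with a small bookkeeping sketch. It assumes that a token stays active while its reach probability is at least a threshold $\tau$ (the caption only says tokens need "sufficient reach probability"), and the gate values below are hypothetical numbers chosen to recreate the figure's pattern; `active_sets` is not a function from the paper.

```python
from typing import Dict, List


def active_sets(gates: Dict[str, List[float]], tau: float, k_max: int) -> List[List[str]]:
    """Return the tokens computed at each latent step 1..k_max.

    gates[token][k - 1] is a (hypothetical) Router gate g^{(k)} for that token.
    A token runs step k, then stays active only while p_reach^{(k+1)} >= tau.
    """
    p_reach = {tok: 1.0 for tok in gates}  # assumption: p_reach^{(1)} = 1
    active = list(gates)
    schedule: List[List[str]] = []
    for k in range(1, k_max + 1):
        if not active:
            break
        # One parallel pass over `active` would compute z^{(k)} here (reusing the KV cache).
        schedule.append(list(active))
        for tok in active:
            p_reach[tok] *= gates[tok][k - 1]
        active = [tok for tok in active if p_reach[tok] >= tau]
    return schedule


schedule = active_sets(
    gates={
        "x1": [0.9, 0.9, 0.9, 0.9],  # hard token: keeps enough reach probability
        "x2": [0.9, 0.3, 0.5, 0.5],  # pruned after step 2
        "x3": [0.2, 0.5, 0.5, 0.5],  # pruned after step 1
    },
    tau=0.5,
    k_max=4,
)
token_steps = sum(len(step) for step in schedule)
```

With $K_{\max}=4$ and three tokens, a fixed-depth unroll would compute 12 token-steps, while this schedule computes 3 + 2 + 1 + 1 = 7. This is the sense in which shrinking the active set lowers training FLOPs.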
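The captions describe the correctness-aware adaptive loss only as being "proportional to $p_{\mathrm{target}}$" and penalizing continuation when the target token is already likely. The sketch below assumes one possible form, $\mathcal{L} = \mathcal{L}_{\mathrm{LM}} + \lambda \sum_k p_{\mathrm{target}}^{(k)}\, p_{\mathrm{reach}}^{(k+1)}$. The exact formula, the weight `lam`, and the helper name `latent_cot_losses` are all hypothetical, and this is not the paper's definition.

```python
import math
from typing import List, Tuple


def latent_cot_losses(
    p_target_steps: List[float],  # p_target^{(k)}: target-token prob. read out from each z^{(k)}
    gates: List[float],           # Router gates g^{(k)} for the same executed steps
    p_target_final: float,        # target prob. from the LM head applied to the mixed z^f
    lam: float = 0.1,             # hypothetical weight for the adaptive term
) -> Tuple[float, float, float]:
    """Hypothetical form: L = L_LM + lam * sum_k p_target^{(k)} * p_reach^{(k+1)}."""
    lm_loss = -math.log(p_target_final)
    adaptive_loss, p_reach = 0.0, 1.0
    for p_t, g in zip(p_target_steps, gates):
        p_reach *= g  # probability of continuing past step k
        # Assumption: in an autograd framework p_t would be detached (stop-gradient),
        # so this term pushes the Router to stop once the target is already likely.
        adaptive_loss += p_t * p_reach
    return lm_loss + lam * adaptive_loss, lm_loss, adaptive_loss


total, lm, adaptive = latent_cot_losses([0.2, 0.95], [0.8, 0.6], p_target_final=0.9)
```

Under this assumed form, continuing past a step where the target is already likely (here $p_{\mathrm{target}}^{(2)}=0.95$) contributes far more penalty than continuing past a low-confidence step.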
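The probing analysis in the Figure 2 caption buckets tokens by their current target-token probability and averages the gain from one more latent step. The sketch below assumes equal-width bins over $[0,1]$ and pools all latent steps into one histogram; the input format and the name `gain_by_confidence` are hypothetical, and the toy trajectories are made up for illustration only.

```python
from collections import defaultdict
from typing import Dict, Sequence


def gain_by_confidence(trajectories: Sequence[Sequence[float]], n_bins: int = 5) -> Dict[int, float]:
    """Average next-step gain p_target^{(k+1)} - p_target^{(k)}, bucketed by p_target^{(k)}.

    trajectories[i][k] is the target-token probability of token i after latent
    step k + 1 (hypothetical input format). Bins are equal-width over [0, 1].
    """
    sums: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for traj in trajectories:
        for k in range(len(traj) - 1):
            b = min(int(traj[k] * n_bins), n_bins - 1)  # bucket by current confidence
            sums[b] += traj[k + 1] - traj[k]            # gain from one more latent step
            counts[b] += 1
    return {b: sums[b] / counts[b] for b in sorted(counts)}


# Toy data: a hard token that keeps improving and a confident token that gets worse.
gains = gain_by_confidence([[0.1, 0.4, 0.7], [0.95, 0.9]])
```

In this toy input, the low- and mid-confidence buckets show positive gains, while the top bucket is negative. That is the qualitative pattern the caption reports as evidence that extra computation can be unnecessary or harmful for already confident tokens.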