PonderLM-3: Adaptive Token-Wise Pondering with Differentiable Masking

He Li, Feichen Song, Boyi Zeng, Shixiang Song, Zhiqin John Xu, Ziwei He, Zhouhan Lin

TL;DR

PonderLM-3 provides an end-to-end differentiable and train-inference consistent framework for token-wise adaptive computation, enabling additional inference compute to be allocated where it is most useful rather than paid uniformly by every token.

Abstract

Test-time scaling has shown that allocating additional computation at inference can improve generation quality, which raises a natural follow-up question: where should this computation be spent? Building on this insight, we introduce PonderLM-3, a pretraining framework for token-wise adaptive pondering, built on top of the PonderLM-2 backbone, that learns to selectively allocate additional computation under purely self-supervised objectives. This makes additional inference computation an allocatable per-token resource, so tokens receive more computation only when it is beneficial, rather than paying a uniform extra cost. To make this allocation learnable while maintaining train-inference consistency, PonderLM-3 injects a differentiable attention mask during pretraining and pairs it with a matching hard pruning rule at inference. PonderLM-3 defines a stronger Pareto frontier: compared with existing recursive or adaptive baselines, it achieves lower pretraining perplexity at equal inference FLOPs. On downstream benchmarks, PonderLM-3 attains performance comparable to fixed-step PonderLM-2 under the same maximum number of additional computation steps, while using fewer inference FLOPs in practice. Overall, PonderLM-3 provides an end-to-end differentiable and train-inference consistent framework for token-wise adaptive computation, enabling additional inference compute to be allocated where it is most useful rather than paid uniformly by every token.

Paper Structure (42 sections, 33 equations, 8 figures, 4 tables)

Figures (8)

  • Figure 1: From fixed compute to token-wise allocation. Standard LMs decode with one forward pass per token. PonderLM-2 applies fixed pondering steps, turning extra compute into a uniform per-token tax. PonderLM-3 makes refinement depth token-dependent, allocating additional internal updates only when they provide non-trivial marginal gains.
  • Figure 2: Token-wise adaptive pondering in PonderLM-3. (1) A lightweight router predicts the step distribution $s_{t,k}$ from $h_t^{(0)}$ and obtains a monotone mask score $w_{t,k}$ via the tail-CDF of $s$; $w_{t,k}\in[0,1]$ serves as a soft attention mask (color intensity indicates $w$). (2) The same LLM performs $K$ pondering steps; at each step, the current hidden state is inserted as a latent token at its corresponding position to form an interleaved sequence, and $w_{t,k}$ is applied in attention as a soft mask. Jacobi iterations are then used to approximate the sequential inference dynamics, and the output uses the $s_{t,k}$-weighted fused hidden state. (3) Inference: token-by-token decoding uses the same attention masking and fusion mechanism and exits early once $w_{t,k}<\tau$, skipping the remaining pondering steps.
  • Figure 3: Performance–compute Pareto curve. The x-axis is the average number of executed additional computation steps per token (an inference FLOPs proxy). PonderLM-3 yields a better compute–quality trade-off than the baselines.
  • Figure 4: Where extra compute helps. (a) Bucket-wise $\Delta CE_i$ over executed additional computation steps. (b) Intrinsic difficulty $\ell_t$ versus executed additional computation steps.
  • Figure 5: Counterfactual compute shifting by sweeping the inference-time router logit bias $\alpha$. We report $\Delta$Loss relative to the adaptive baseline ($\alpha=0$) versus the induced average executed additional computation steps (left: easy subset; right: hard subset).
  • ...and 3 more figures
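
The routing mechanism described in the Figure 2 caption can be illustrated with a minimal sketch. It is not the authors' implementation. It assumes the router emits logits over step counts $0,\dots,K$, that $s_{t,k}$ is their softmax, and that the tail-CDF is $w_{t,k}=\sum_{j\ge k} s_{t,j}$; the exact indexing, the function names, the example logits, and the threshold value are all hypothetical. The sketch shows why $w_{t,k}$ is monotone in $k$ and how the hard exit rule $w_{t,k}<\tau$ yields a per-token step count, whose mean is the kind of quantity plotted on the x-axis of Figure 3.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def tail_cdf(step_probs):
    """Return w with w[k] = sum of step_probs[j] for j >= k.

    w is non-increasing in k and w[0] is 1 up to rounding.
    (Assumed definition of the tail-CDF; the source does not spell it out.)
    """
    w, running = [], 0.0
    for p in reversed(step_probs):
        running += p
        w.append(running)
    w.reverse()
    return w


def executed_steps(mask_scores, tau):
    """Count pondering steps run before the first w[k] < tau (hard early exit)."""
    steps = 0
    for w_k in mask_scores[1:]:  # index 0 is the base forward pass, always run
        if w_k < tau:
            break
        steps += 1
    return steps


# Hypothetical router logits for three tokens over step counts 0..K, with K = 3.
router_logits = [
    [4.0, 0.0, 0.0, 0.0],  # mass on 0 extra steps
    [0.0, 0.0, 0.0, 4.0],  # mass on 3 extra steps
    [0.0, 2.0, 2.0, 0.0],  # mass on 1 or 2 extra steps
]
tau = 0.3  # hypothetical exit threshold

steps = []
for logits in router_logits:
    s = softmax(logits)   # step distribution s_{t,k}
    w = tail_cdf(s)       # monotone mask scores w_{t,k}
    steps.append(executed_steps(w, tau))

avg_steps = sum(steps) / len(steps)
print(steps, avg_steps)
```

Because the tail-CDF is non-increasing in $k$, once a score falls below $\tau$ every later score does too, so stopping at the first failure skips exactly the steps whose soft mask would have been below the threshold. During training the same $w_{t,k}$ values would be applied as soft weights in attention rather than thresholded, which this sketch does not model.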