AdaPonderLM: Gated Pondering Language Models with Token-Wise Adaptive Depth

Shixiang Song, He Li, Zitong Wang, Boyi Zeng, Feichen Song, Yixuan Wang, Zhiqin John Xu, Ziwei He, Zhouhan Lin

TL;DR

AdaPonderLM is a self-supervised recurrent language model that learns token-wise early exiting during pretraining without manually tuned per-token/per-layer pruning ratios; its learned gates allocate more computation to high-NLL (hard) tokens, exhibiting adaptive computation time behavior in a fully self-supervised setting.

Abstract

Test-time scaling via recurrent/iterative Transformers enables large language models to spend more computation at inference, but most pretrained recurrent LMs run a fixed number of iterations, wasting compute on easy tokens and lacking token-wise adaptivity. Following the core ideas of Adaptive Computation Time (ACT) and Early Exit (EE), we propose AdaPonderLM, a self-supervised recurrent language model that learns token-wise early exiting during pretraining without manually tuned per-token/per-layer pruning ratios. AdaPonderLM uses iteration-specific MLP gates with a monotonic halting mask to decide when each token stops recurring, and introduces a KV reuse mechanism that reuses cached key/value states for halted tokens, ensuring train-test consistency and practical acceleration. Across Pythia backbones from 70M to 410M parameters (pretraining) and up to 2.8B (continued pretraining), AdaPonderLM reduces inference compute by about 10% while maintaining comparable language modeling perplexity and competitive downstream accuracy. Our analysis shows that the learned gates allocate more computation to high-NLL (hard) tokens, exhibiting adaptive computation time behavior in a fully self-supervised setting. Moreover, under iso-FLOPs, the learned halting policy consistently outperforms fixed pruning, showing that AdaPonderLM allocates compute to the right tokens rather than merely reducing average depth.
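To make the "monotonic halting mask" concrete, here is a minimal inference-time sketch in plain Python. It is not the authors' implementation. It assumes (our assumptions) that each gate output is the probability that a token should run one more iteration, that a hard threshold of 0.5 turns it into a decision, and that every token runs the first iteration. The names `halting_masks` and `compute_fraction` are hypothetical. The mask update multiplies by the previous mask, so once a token halts it never resumes. The compute fraction counts executed token-iterations relative to a fixed-depth model. The toy gates mirror Figure 1, so the savings here (4 of 18 token-iterations) are illustrative and unrelated to the paper's reported ~10%.

```python
from typing import List


def halting_masks(continue_probs: List[List[float]], threshold: float = 0.5) -> List[List[int]]:
    """Return the active mask m^i for every iteration (1 = token still recurring).

    continue_probs[i][t] is a hypothetical gate output produced after iteration
    i + 1: the probability that token t should run one more iteration.
    """
    mask = [1] * len(continue_probs[0])  # every token runs the first iteration
    masks = [mask]
    for probs in continue_probs:
        # Monotonic update: multiplying by the previous mask means a halted
        # token (0) can never be reactivated, whatever its later gate says.
        mask = [m * int(p >= threshold) for m, p in zip(mask, probs)]
        masks.append(mask)
    return masks


def compute_fraction(masks: List[List[int]]) -> float:
    """Executed token-iterations divided by the fixed-depth budget."""
    executed = sum(sum(m) for m in masks)
    return executed / (len(masks) * len(masks[0]))


# Toy gates for six tokens A..F and three iterations (two gate decisions),
# chosen to mirror Figure 1: B halts after iteration 1, D and F after iteration 2.
probs = [
    [0.9, 0.2, 0.8, 0.7, 0.9, 0.6],
    [0.8, 0.9, 0.7, 0.3, 0.6, 0.1],
]
masks = halting_masks(probs)
print(masks[-1])                # → [1, 0, 1, 0, 1, 0]
print(compute_fraction(masks))  # 14 of 18 token-iterations executed
```

Note that token B's second gate value (0.9) is ignored because its mask is already 0. This is exactly what the monotonicity constraint enforces.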

Paper Structure (36 sections, 14 equations, 8 figures, 4 tables, 2 algorithms)

Figures (8)

  • Figure 1: Illustration of the gating behavior at inference time. Across recurrent iterations, tokens that are pruned reuse their KV cache from the previous iteration, avoiding redundant computation. In this example, token B is halted by the gating MLP after the first iteration; its KV cache is then reused in the second and third iterations. Tokens D and F are halted in the second iteration and similarly reuse their KV caches thereafter.
  • Figure 2: Mechanism overview. The same Transformer is executed recurrently across iterations. At iteration $i$, an iteration-specific MLP produces gate probabilities, which update a persistent mask $m^{i}$. KV states are aligned token-wise via a where operation: pruned tokens reuse their cached KV states, ensuring training-inference consistency.
  • Figure 3: Hyperparameter Sensitivity. Loss increases monotonically with $\lambda$ (2.7425 $\rightarrow$ 2.7455) and $k$ (2.7425 $\rightarrow$ 2.7454), indicating that lower regularization strengths are preferred for minimizing loss.
  • Figure 4: Gating mechanism reduces effective compute steps by 10% without compromising representation quality.
  • Figure 5: Dynamics of pruning across recurrent iterations. Top: mean token NLL for tokens halted at different forward steps. Bottom: density map of the token NLL distribution (clipped for clarity).
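The KV reuse step in Figures 1 and 2 can be pictured as a token-wise select between freshly computed and cached key/value states, analogous to `torch.where(condition, input, other)` in PyTorch. The plain-Python sketch below is an illustration under simplifying assumptions, not the paper's code. It keeps one KV entry per token rather than per layer and head. A string tag such as `"B@1"` stands in for the real tensors and records which iteration produced that state. The functions `reuse_kv` and `run_iterations` are hypothetical names. Following Figure 1, token B halts after iteration 1, and D and F halt after iteration 2.

```python
from typing import List, Sequence, TypeVar

KV = TypeVar("KV")


def reuse_kv(active: Sequence[int], new_kv: Sequence[KV], cached_kv: Sequence[KV]) -> List[KV]:
    """Token-wise select, like where(active, new_kv, cached_kv)."""
    return [new if a else old for a, new, old in zip(active, new_kv, cached_kv)]


def run_iterations(tokens: str, masks: List[List[int]]) -> List[str]:
    cache: List[str] = []
    for i, mask in enumerate(masks, start=1):
        # Stand-in for one pass of the shared Transformer: tag each token's KV
        # with the iteration that produced it. All tokens are "computed" here and
        # then selected, as a where-based batched pass would do; a real
        # inference kernel would skip halted tokens to save compute.
        new_kv = [f"{tok}@{i}" for tok in tokens]
        cache = new_kv if i == 1 else reuse_kv(mask, new_kv, cache)
    return cache


# Active masks per iteration mirroring Figure 1 (1 = still recurring).
FIGURE1_MASKS = [
    [1, 1, 1, 1, 1, 1],
    [1, 0, 1, 1, 1, 1],
    [1, 0, 1, 0, 1, 0],
]
print(run_iterations("ABCDEF", FIGURE1_MASKS))
# → ['A@3', 'B@1', 'C@3', 'D@2', 'E@3', 'F@2']
```

Because halted tokens keep the KV state from their last executed iteration, later active tokens attend to the same states at training and inference time. This is the consistency property the Figure 2 caption refers to.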