
Teaching Pretrained Language Models to Think Deeper with Retrofitted Recurrence

Sean McLeish, Ang Li, John Kirchenbauer, Dayal Singh Kalra, Brian R. Bartoldson, Bhavya Kailkhura, Avi Schwarzschild, Jonas Geiping, Tom Goldstein, Micah Goldblum

TL;DR

The paper addresses the challenge of decoupling train-time compute from test-time compute in language models by retrofitting depth recurrence onto pretrained transformers. It introduces a prelude–recurrent–coda architecture and a curriculum that gradually increases recurrence depth, enabling deeper reasoning with efficient training. Key contributions include demonstrating pretrained-weight initialization benefits, recurrence scheduling to reduce training cost, and effective data mixtures and healing to preserve language modeling while boosting math reasoning on TinyLlama, OLMo, and Llama models. The results show that depth-recurrent models can achieve higher GSM8K and MATH performance under the same training FLOPs with fewer trainable parameters, highlighting a practical route to scalable, reasoning-focused language models and offering insights for future adaptive-compute architectures.

Abstract

Recent advances in depth-recurrent language models show that recurrence can decouple train-time compute and parameter count from test-time compute. In this work, we study how to convert existing pretrained non-recurrent language models into depth-recurrent models. We find that using a curriculum of recurrences to increase the effective depth of the model over the course of training preserves performance while reducing total computational cost. In our experiments on mathematics, we observe that converting pretrained models to recurrent ones results in better performance at a given compute budget than simply post-training the original non-recurrent language model.
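The curriculum of recurrences described above can be sketched as a schedule on the mean number of recurrences per training step. This is a minimal sketch under stated assumptions: the linear ramp up to a final mean of 32 follows the Figure 3 caption, but the starting mean, the ramp length (`ramp_fraction`), and drawing each step's depth from a Poisson distribution are hypothetical choices made here for illustration, not the authors' exact settings.

```python
import math
import random


def scheduled_mean(step: int, total_steps: int, final_mean: float = 32.0,
                   start_mean: float = 1.0, ramp_fraction: float = 0.75) -> float:
    """Linearly ramp the mean recurrence count from start_mean to final_mean.

    start_mean and ramp_fraction are hypothetical: ramp_fraction is the share
    of training spent ramping, after which the mean stays at final_mean.
    """
    ramp_steps = max(1, int(total_steps * ramp_fraction))
    progress = min(step / ramp_steps, 1.0)
    return start_mean + progress * (final_mean - start_mean)


def sample_recurrences(mean: float, rng: random.Random) -> int:
    """Draw a recurrence count from Poisson(mean) via Knuth's method, clamped to >= 1.

    The Poisson distribution is an assumption for illustration only.
    """
    threshold = math.exp(-mean)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= threshold:
            break
        k += 1
    return max(1, k)


if __name__ == "__main__":
    rng = random.Random(0)
    total = 1000
    for step in (0, 250, 500, 750, 999):
        m = scheduled_mean(step, total)
        print(step, round(m, 2), sample_recurrences(m, rng))
```

Because the mean starts small, early training steps run few recurrences and are cheap; this is the intuition behind the compute savings reported in the abstract, although the paper's exact schedule and cost accounting may differ.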

Paper Structure

This paper contains 25 sections, 1 equation, 43 figures, 9 tables.

Figures (43)

  • Figure 1: We take layers from pretrained language models and recur a core block. We take early layers to form the prelude and later layers to form the recurrent block and coda, removing the layers in between. After each recurrence, we concatenate the output of the prelude with the output of the recurrent block (or random noise at time zero) and apply a linear adapter.
  • Figure 2: Initializing from pretrained Llama layers gives a significant advantage in loss and benchmark accuracy. Left: Loss over training steps for $120$ billion tokens, for models initialized from Llama-3.2-1B layers and for models initialized randomly [takase2023spike]. Although it starts higher, the model initialized from Llama weights consistently achieves lower loss than the randomly initialized model. Right: Zero-shot accuracy on HellaSwag [zellers2019hellaswag] over training steps for recurrences $[1,2,4,8,16,32]$. The Llama-based model (blue) reaches higher accuracy sooner and leverages recurrence effectively from early in training. Accuracy over recurrence for a suite of language modeling benchmarks is reported in Appendix [app-tab:random_init_benchmarks_full].
  • Figure 3: Scheduling the mean of the depth distribution is efficient in terms of both data and compute. We report validation loss over multiple recurrent depths in terms of steps (i.e., data) on the left and in terms of FLOPs on the right. Linearly scheduling the number of recurrences up to the final mean $(32)$ over a long period of training decreases the validation loss, so the curriculum is both data- and compute-efficient. Curricula of alternative lengths and additional test-time recurrence depths are shown in Appendix [app-fig:schedule-mean-n-all].
  • Figure 4: Muon improves over AdamW when training recurrent models. Left: Loss vs. step for multiple training runs on the same data order with different optimizers, using a learning rate of $5\times10^{-5}$ for AdamW and $0.001$ for Muon. Muon is the most stable and achieves the lowest loss for recurrent models. Note that the AdamW line ends early because the loss spikes and becomes NaN. Right: Loss (smoothed over $50$ steps) vs. step for AdamW and Muon. For the non-recurrent TinyLlama model there is minimal difference between the optimizers.
  • Figure 5: Recurrence improves reasoning on GSM8K for TinyLlama, even when controlling for FLOPs. We train $(4,8,4)$ and non-recurrent models for approximately $50$ billion tokens of Nemotron-CC-Math-v1 data. Left: Accuracy over the number of FLOPs used during training; recurrent models trained with scheduling can efficiently outperform the non-recurrent baseline. Right: Accuracy over the number of recurrences used at inference; the recurrent models are competitive with the fixed-depth baseline and can outperform it by using more recurrences, and therefore more FLOPs. Each individual model's accuracy over training step and recurrence, including for training recurrences $8$ and $32$, is plotted in full in Appendix [app-subsubsec:retrofit-tinyllama]. Evaluations of the final checkpoint on the tasks shown in Table [tab:data-mix] are in Appendix [app-tab:tinyllama-all-evals]. Identical experiments for OLMo and Llama are provided in Appendix [app-subsec:retrofit].
  • ...and 38 more figures
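The prelude–recurrent–coda data flow described in the Figure 1 caption can be sketched in plain Python. This is a toy illustration under stated assumptions: `toy_layer` is a hypothetical stand-in for a pretrained transformer layer (no attention, no normalization), each vector represents a single token's hidden state, and reading a configuration such as $(4,8,4)$ as the layer counts of the prelude, recurrent block, and coda is an assumption made for the hypothetical `effective_depth` helper.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]
Layer = Callable[[Vector], Vector]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def toy_layer(d: int, rng: random.Random) -> Layer:
    """Hypothetical stand-in for a pretrained layer: x + tanh(W x)."""
    W = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(d)]
    return lambda x: [xi + math.tanh(h) for xi, h in zip(x, matvec(W, x))]


def run_stack(layers: List[Layer], x: Vector) -> Vector:
    """Apply a list of layers in order."""
    for layer in layers:
        x = layer(x)
    return x


def recurrent_forward(x: Vector, prelude: List[Layer], core: List[Layer],
                      coda: List[Layer], adapter: Matrix,
                      num_recurrences: int, rng: random.Random) -> Vector:
    e = run_stack(prelude, x)                  # prelude runs once
    s = [rng.gauss(0.0, 1.0) for _ in e]       # random-noise state at time zero
    for _ in range(num_recurrences):
        s = matvec(adapter, e + s)             # concat [e; s] (size 2d); adapter maps back to d
        s = run_stack(core, s)                 # same core weights reused every iteration
    return run_stack(coda, s)                  # coda runs once on the final state


def effective_depth(p: int, r: int, c: int, k: int) -> int:
    """Layers applied per forward pass with k recurrences (adapter not counted)."""
    return p + r * k + c


if __name__ == "__main__":
    rng = random.Random(0)
    d = 4
    prelude = [toy_layer(d, rng) for _ in range(4)]
    core = [toy_layer(d, rng) for _ in range(8)]
    coda = [toy_layer(d, rng) for _ in range(4)]
    adapter = [[rng.gauss(0.0, 0.1) for _ in range(2 * d)] for _ in range(d)]
    out = recurrent_forward([1.0] * d, prelude, core, coda, adapter, 4, rng)
    print(len(out), effective_depth(4, 8, 4, 4))  # prints "4 40"
```

Under this reading, a $(4,8,4)$ model with $k$ recurrences applies $4 + 8k + 4$ layers (40 at $k=4$, 264 at $k=32$) while holding only 16 distinct layers of weights plus the adapter, which illustrates how recurrence decouples test-time compute from parameter count.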