How Reinforcement Learning After Next-Token Prediction Facilitates Learning

Nikolaos Tsilivis, Eran Malach, Karen Ullrich, Julia Kempe

TL;DR

The paper investigates why reinforcement learning (RL) applied after next-token prediction helps large language models learn when they are trained on data that mix short and long chain-of-thought demonstrations. Using a framework built around the parity task and a theoretical analysis of autoregressive linear models, it identifies a critical threshold near $p_{\mathrm{cot}} = 1/3$: pre-training alone fails to generalize when long demonstrations are rare, whereas RL post-training achieves generalization efficiently as long as long demonstrations are not exponentially scarce. Empirically, the authors reproduce these phenomena with transformers trained from scratch on parity data and on harder reasoning tasks (number multiplication, GSM8K, MATH) using GPT-2, Mistral, and Llama-series models; RL increases accuracy and induces longer, more informative outputs. The findings give a principled account of the practical benefits of RL for LLMs, highlight the role of chain-of-thought data availability, and identify the increase in response length as an optimization-driven learning signal. Overall, the work combines theory and extensive experiments to explain how RL after next-token prediction can substantially improve learning efficiency on challenging reasoning tasks.

Abstract

Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long "chain-of-thought" sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas mere next-token prediction requires extreme statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits as long as the proportion of long demonstrations in the data mix is not exponentially small in the input dimension $d$. Finally, we demonstrate these same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.
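The mixture of short and long chain-of-thought sequences for the parity task can be sketched as follows. This is a minimal illustration under our own assumptions, not the paper's exact data format: we assume the long format writes out the running parities $x_1, x_1 \oplus x_2, \ldots$ up to the full parity (so a long answer has $d$ tokens), and the short format emits only the final parity bit. The names `long_cot`, `parity_example`, and `sample_mixture` are hypothetical.

```python
import random


def long_cot(bits):
    """Assumed long format: running parities x1, x1^x2, ..., ending in the full parity."""
    steps, acc = [], 0
    for b in bits:
        acc ^= b  # each step XORs one new bit into the previous partial parity
        steps.append(acc)
    return steps


def parity_example(d, p_cot, rng):
    """Draw one (input bits, target tokens) pair from a hypothetical mixture D(p_cot)."""
    bits = [rng.randint(0, 1) for _ in range(d)]
    if rng.random() < p_cot:
        target = long_cot(bits)   # long chain-of-thought demonstration, d tokens
    else:
        target = [sum(bits) % 2]  # short demonstration: final parity only
    return bits, target


def sample_mixture(n, d, p_cot, seed=0):
    """Sample n sequences; roughly a p_cot fraction use the long format."""
    rng = random.Random(seed)
    return [parity_example(d, p_cot, rng) for _ in range(n)]


if __name__ == "__main__":
    print(long_cot([1, 0, 1, 1]))  # [1, 1, 0, 1]; the last entry is the parity
    for bits, target in sample_mixture(3, d=5, p_cot=1 / 3, seed=0):
        print(bits, "->", target)
```

In this sketch every target ends in the parity bit, so imitating only the short format means computing a $d$-bit parity in a single step, while the long format splits it into $d$ single-XOR steps.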

Paper Structure

This paper contains 76 sections, 12 theorems, 158 equations, 41 figures, 3 algorithms.

Key Result

Theorem 1

(Pre-training, Informal) Let $d \geq 2,\, p_{\mathrm{cot}} \in (0, 3/4)$. Consider the distribution $\mathcal{D}(p_{\mathrm{cot}})$, as defined in the paper's setup section, and the linear hypothesis classes $\mathcal{H}_1, \mathcal{H}_{2a}, \mathcal{H}_2, \ldots, \mathcal{H}_d$ as defined earlier. Consider running …

Figures (41)

  • Figure 1: Left: An illustration of our main learning setting: a mixture of long and short sequences encoding the parity of $d$ bits, along with a representation of pre-training and post-training for $d = 5$. Right: Advantage of next-token prediction followed by reinforcement learning over mere next-token prediction in predicting the parity of $d = 50$ bits with a transformer trained from scratch. The red line corresponds to pre-training using next-token prediction, while blue lines correspond to the same pre-training runs, each followed by GRPO (with a final-token accuracy reward) from a different checkpoint. Lines show the median across 3 seeds. Top: Test accuracy under greedy decoding. Bottom: Median length of model response under greedy decoding. The inset figures zoom in on the curves during post-training. Accuracy plateaus around 50% (random guess) during pre-training and then immediately grows during post-training. Response length also increases during post-training.
  • Figure 2: Pre-training of transformers on a mixture of long and short sequences encoding the parity of $d$ bits. Left: Test accuracy during the course of pre-training. Center: The probability that a model's generation length equals the maximum length present in the training distribution (which equals $d$). Solid lines correspond to greedy decoding, while dashed lines correspond to sampling with temperature 1. Each color denotes a training distribution with a separate mixture coefficient $p_{\mathrm{cot}}$. The figure shows the mean and one standard deviation across 3 seeds. Right: Test accuracy with greedy decoding at the end of pre-training (50k iterations) versus mixture coefficient $p_{\mathrm{cot}}$. The red dashed line corresponds to the critical threshold. Each bullet is the median of 3 runs.
  • Figure 3: Post-training of transformers on a mixture of long and short sequences encoding the parity of $d$ bits with various RL methods and generation temperatures ($\tau_{\mathrm{RL}}$). Left: Test accuracy during the course of post-training with greedy decoding (solid lines) and sampling with temperature 1 (dashed lines) for various RL methods. The figure shows the mean and one standard deviation across 3 seeds. Right: Length of the generated response (sampled with temperature 1) during the course of a post-training run (GRPO with end-to-end reward) for 640 test inputs, after 20k pre-training iterations. Note: the sample size $n$ of each RL iteration differs among the RL algorithms: $n = 64$ sequences for GRPO and REINFORCE, and $n = 3{,}200$ sequences for STaR.
  • Figure 4: Example of a 4x4 sequence, encoding 2365*4374 (digits appear in reverse order). Top row: long format with full sequence. Bottom row: short format without chain of thought.
  • Figure 5: Advantage of next-token prediction followed by reinforcement learning over mere next-token prediction in (Left) multiplying two 5-digit numbers with a GPT-2 model trained from scratch, and in (Right) solving grade-school math problems with Llama 3.2 3B (base). Left: Red line: pre-training on a 5x5 dataset with $p_{\mathrm{cot}} = 0.25$ using next-token prediction. Blue lines: the same pre-training runs, each followed by GRPO from a different checkpoint. Top: median test accuracy under greedy decoding (3 seeds). Bottom: median average response length (median over seeds, averaged over sequences). Right: $\star$ markers: supervised fine-tuning checkpoints on GSM8K with $p_{\mathrm{cot}} = 0.05$ (epochs 1–30). Colored curves: GRPO post-training from these checkpoints. Training sequences are reused across epochs; during post-training, multiple generations per input are counted.
  • ...and 36 more figures
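Several captions above refer to GRPO post-training with a final-token accuracy reward. The reward-and-advantage step can be sketched as follows; this is a simplified illustration under stated assumptions, not the authors' training code. We assume a binary reward (1 if the last generated token equals the parity of the input bits, 0 otherwise), standardize rewards within the group of samples drawn for one prompt using the population standard deviation, and use a plain advantage-weighted log-likelihood surrogate that omits the clipped importance ratio and KL penalty of full GRPO. All function names are hypothetical.

```python
import statistics


def final_token_reward(generated, bits):
    """Assumed reward: 1.0 if the last generated token is the parity of `bits`, else 0.0."""
    if not generated:
        return 0.0
    return 1.0 if generated[-1] == sum(bits) % 2 else 0.0


def group_relative_advantages(rewards, eps=1e-6):
    """Standardize rewards within one group of samples for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mean) / (std + eps) for r in rewards]


def surrogate_loss(advantages, logprobs):
    """Advantage-weighted negative log-likelihood (no ratio clipping, no KL term)."""
    return -statistics.fmean(a * lp for a, lp in zip(advantages, logprobs))


if __name__ == "__main__":
    bits = [1, 0, 1, 1, 0]                       # parity is 1
    group = [[1, 1, 0, 1, 1], [0], [1], [0, 0]]  # hypothetical sampled responses
    rewards = [final_token_reward(g, bits) for g in group]
    print(rewards)                               # [1.0, 0.0, 1.0, 0.0]
    print(group_relative_advantages(rewards))
```

In this sketch, a group in which every sample receives the same reward (for example, all wrong) yields zero advantage for every sample and hence no learning signal, which illustrates why the pre-trained model must occasionally sample a correct response for this kind of reward to help.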

Theorems & Definitions (31)

  • Theorem 1
  • Remark 1
  • Remark 2
  • Theorem 2
  • Remark 3
  • Remark 4
  • Definition 1
  • Theorem 3
  • Proof
  • Remark 5
  • ...and 21 more