How Reinforcement Learning After Next-Token Prediction Facilitates Learning
Nikolaos Tsilivis, Eran Malach, Karen Ullrich, Julia Kempe
TL;DR
The paper investigates why reinforcement learning (RL) after next-token prediction helps large language models learn from data that mix short and long chain-of-thought demonstrations. Using a parity-focused framework and a theoretical analysis of autoregressive linear models, it identifies a critical threshold near p_cot ≈ 1/3, where p_cot is the proportion of long demonstrations in the data mix: pre-training alone fails to generalize when long demonstrations are rare, but RL post-training can achieve generalization efficiently provided that long demonstrations are not exponentially scarce. Empirically, the authors demonstrate these phenomena in transformers trained from scratch on parity data and on deeper reasoning tasks (number multiplication, GSM8K, MATH) using GPT-2, Mistral, and Llama-series models, showing that RL boosts accuracy and induces longer, more informative outputs. The findings give a principled account of the practical benefits of RL in LLMs, emphasize the role of chain-of-thought data availability, and show that the increase in response length is driven by optimization and facilitates learning. Overall, the work combines theory and extensive experiments to explain how RL after next-token prediction can dramatically improve learning efficiency on challenging reasoning tasks.
Abstract
Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long ``chain-of-thought'' sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas mere next-token prediction requires extreme statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits as long as the proportion of long demonstrations in the data mix is not exponentially small in the input dimension $d$. Finally, we demonstrate these same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.
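
As a rough illustration of the data mix described above, the following Python sketch samples one training sequence for the $d$-bit parity task. It is a minimal sketch under stated assumptions, not the authors' code: the function name `sample_sequence`, the 0/1 bit encoding, and the choice of running partial parities as the long ``chain-of-thought'' are hypothetical; only the mixture proportion (written p_cot in the TL;DR) and the parity task come from the text.

```python
import random


def sample_sequence(d, p_cot, rng):
    """Draw one sequence from a short/long mixture for d-bit parity.

    Assumed format (hypothetical, for illustration only):
      long  = input bits followed by all running partial parities
      short = input bits followed by the final parity only
    """
    bits = [rng.randint(0, 1) for _ in range(d)]

    # Running XOR: entry i is the parity of the first i + 1 bits.
    partial = []
    acc = 0
    for b in bits:
        acc ^= b
        partial.append(acc)

    # With probability p_cot, emit the long demonstration.
    if rng.random() < p_cot:
        return bits + partial, True
    return bits + [partial[-1]], False


if __name__ == "__main__":
    rng = random.Random(0)
    # Rare long demonstrations: most samples carry only the final answer.
    data = [sample_sequence(8, 0.1, rng) for _ in range(1000)]
    n_long = sum(is_long for _, is_long in data)
    print(f"long demonstrations: {n_long} / {len(data)}")
```

In both cases the last token is the parity of the $d$ input bits, so a short and a long sequence encode the same task and differ only in whether the intermediate steps are shown.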
