Think, Prune, Train, Improve: Scaling Reasoning without Scaling Models
Caia Costello, Simon Guo, Anna Goldie, Azalia Mirhoseini
TL;DR
The paper addresses the challenge of improving reasoning in small- to mid-sized LLMs without relying on larger teacher models. It proposes the Think, Prune, Train (TPT) framework: a model generates its own reasoning traces, solutions that fail a ground-truth correctness check are pruned, and the model is then fine-tuned with supervision on the validated outputs. Key findings show that correctness-based pruning stabilizes training and enables meaningful self-improvement over recursive fine-tuning rounds, with substantial Pass@1 gains on GSM8K (e.g., Gemma2-2B: 41.9%→57.6%; Gemma2-9B: 66.4%→82.4%; LLaMA-3.1-70B: 78.6%→91%), the last of which surpasses GPT-4o on GSM8K. The work also shows that simply scaling synthetic data is not universally beneficial, that pruning is crucial for data quality, and that the framework achieves competitive results without distillation from a larger teacher model, offering a path to scalable reasoning improvements in smaller models.
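The pruning step relies on a ground-truth correctness filter. A minimal sketch for GSM8K-style data is shown below, assuming that reference solutions end with the dataset's `#### <answer>` line and that the last number in a generated trace is the model's final answer. The helper names (`gold_answer`, `predicted_answer`, `prune`) and the last-number heuristic are hypothetical illustrations, not the paper's actual answer-extraction code.

```python
import re
from typing import Iterable, List, Optional, Tuple

# GSM8K reference solutions end with a line such as "#### 72" or "#### 1,234".
_GOLD_RE = re.compile(r"####\s*(-?[\d,]*\.?\d+)")
# Any number in free-form text (commas allowed as thousands separators).
_NUM_RE = re.compile(r"-?\d[\d,]*\.?\d*")


def gold_answer(reference: str) -> Optional[str]:
    """Extract the ground-truth answer from a GSM8K-style reference."""
    m = _GOLD_RE.search(reference)
    return m.group(1).replace(",", "") if m else None


def predicted_answer(trace: str) -> Optional[str]:
    """Hypothetical heuristic: treat the last number in the trace as the answer."""
    nums = _NUM_RE.findall(trace)
    if not nums:
        return None
    return nums[-1].replace(",", "").rstrip(".")


def _same(a: str, b: str) -> bool:
    """Compare answers numerically when possible, else as strings."""
    try:
        return float(a) == float(b)
    except ValueError:
        return a == b


def prune(samples: Iterable[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
    """Keep (question, trace) pairs whose final answer matches ground truth.

    `samples` holds (question, generated_trace, reference_solution) triples.
    """
    kept = []
    for question, trace, reference in samples:
        gold = gold_answer(reference)
        pred = predicted_answer(trace)
        if gold is not None and pred is not None and _same(pred, gold):
            kept.append((question, trace))
    return kept
```

Only traces that survive `prune` would be used as supervised fine-tuning data. A stricter extractor (e.g., requiring an explicit "The answer is" marker) trades recall for fewer false positives.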
Abstract
Large language models (LLMs) have demonstrated strong capabilities in programming and mathematical reasoning tasks, but are constrained by limited high-quality training data. Synthetic data can be leveraged to enhance fine-tuning outcomes, but several factors influence this process, including model size, synthetic data volume, pruning strategy, and number of fine-tuning rounds. We explore these axes and investigate which conditions enable model self-improvement. We introduce the Think, Prune, Train process, a scalable framework that iteratively fine-tunes models on their own reasoning traces, using ground-truth pruning to ensure high-quality training data. This approach yields clear gains on GSM8K: Gemma2-2B improves its Pass@1 from 41.9% to 57.6%, Gemma2-9B reaches 82%, matching LLaMA-3.1-70B, and LLaMA-3.1-70B attains 91%, surpassing GPT-4o. These results demonstrate the effectiveness of self-generated reasoning and systematic data selection for improving LLM capabilities.
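The iterative Think, Prune, Train loop can be sketched as below. This is a minimal illustration under stated assumptions: `generate`, `is_correct`, and `fine_tune` are hypothetical callables standing in for sampling, the ground-truth check, and supervised fine-tuning; the sketch continues each round from the previous round's model, which may differ from the paper's exact schedule (e.g., restarting from the base model). Pass@1 is computed here from a single sample per problem, so it reduces to accuracy.

```python
from typing import Any, Callable, List, Tuple

Example = Tuple[str, str]  # (question, reference answer)


def pass_at_1(model: Any, eval_set: List[Example],
              generate: Callable[[Any, str], str],
              is_correct: Callable[[str, str], bool]) -> float:
    """Fraction of problems solved with one sample each."""
    hits = sum(is_correct(generate(model, q), ref) for q, ref in eval_set)
    return hits / len(eval_set)


def think_prune_train(model: Any, train_set: List[Example],
                      generate: Callable[[Any, str], str],
                      is_correct: Callable[[str, str], bool],
                      fine_tune: Callable[[Any, List[Tuple[str, str]]], Any],
                      rounds: int = 3,
                      samples_per_question: int = 4):
    """Hypothetical TPT loop; returns the final model and kept-trace counts."""
    kept_counts = []
    for _ in range(rounds):
        # Think: sample several reasoning traces per training question.
        traces = [(q, generate(model, q), ref)
                  for q, ref in train_set
                  for _ in range(samples_per_question)]
        # Prune: keep only traces that pass the ground-truth check.
        kept = [(q, t) for q, t, ref in traces if is_correct(t, ref)]
        kept_counts.append(len(kept))
        # Train: fine-tune on the surviving traces (skip if none survived).
        if kept:
            model = fine_tune(model, kept)
    return model, kept_counts
```

Tracking how many traces survive pruning per round gives a cheap signal of whether self-improvement is progressing; evaluating `pass_at_1` on a held-out set after each round would show whether the gains transfer.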
