
Think, Prune, Train, Improve: Scaling Reasoning without Scaling Models

Caia Costello, Simon Guo, Anna Goldie, Azalia Mirhoseini

TL;DR

The paper addresses the challenge of improving reasoning in small- to mid-sized LLMs without relying on larger teacher models. It proposes the Think, Prune, Train (TPT) framework: a model generates its own reasoning traces, incorrect or suboptimal solutions are pruned using ground-truth correctness filters, and the model is then supervised fine-tuned on the validated outputs. Key findings show that correctness-based pruning stabilizes training and enables meaningful self-improvement through recursive fine-tuning, with substantial Pass@1 gains on GSM8K (e.g., Gemma2-2B: 41.9%→57.6%; Gemma2-9B: 66.4%→82.4%; LLaMA-3.1-70B: 78.6%→91%), with LLaMA-3.1-70B surpassing GPT-4o on GSM8K. The work also shows that simply scaling synthetic data is not universally beneficial, that pruning is crucial for data quality, and that the framework achieves competitive results without a larger teacher model, offering a path to scalable reasoning improvements in smaller models.
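The generate, prune, fine-tune loop described above can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the authors' implementation: `think`, `prune`, `tpt`, the `Model`/`Trainer`/`Checker` callables, and the defaults (four rounds, four samples per prompt) are hypothetical choices for illustration. A real setup would replace the callables with LLM sampling, a ground-truth answer check, and an SFT routine.

```python
from typing import Callable, List, Tuple

# Hypothetical interfaces for illustration only:
# a Model maps a prompt to one sampled reasoning trace, a Checker
# accepts (prompt, trace) and is assumed to look up the prompt's
# ground-truth answer, and a Trainer returns a new Model fine-tuned
# (SFT) on (prompt, trace) pairs.
Model = Callable[[str], str]
Checker = Callable[[str, str], bool]
Trainer = Callable[[Model, List[Tuple[str, str]]], Model]


def think(model: Model, prompts: List[str], samples_per_prompt: int) -> List[Tuple[str, str]]:
    """Think: sample several reasoning traces per prompt from the current model."""
    return [(p, model(p)) for p in prompts for _ in range(samples_per_prompt)]


def prune(traces: List[Tuple[str, str]], is_correct: Checker) -> List[Tuple[str, str]]:
    """Prune: keep only traces that the ground-truth checker accepts."""
    return [(p, t) for p, t in traces if is_correct(p, t)]


def tpt(model: Model, prompts: List[str], is_correct: Checker, train: Trainer,
        rounds: int = 4, samples_per_prompt: int = 4) -> Tuple[Model, List[int]]:
    """Run `rounds` Think-Prune-Train iterations.

    Each round's fine-tuned model generates the next round's data.
    Returns the final model and the number of traces kept in each round.
    """
    kept_per_round: List[int] = []
    for _ in range(rounds):
        kept = prune(think(model, prompts, samples_per_prompt), is_correct)
        kept_per_round.append(len(kept))
        model = train(model, kept)  # Train: SFT on the model's own validated traces
    return model, kept_per_round


if __name__ == "__main__":
    # Toy stand-ins: the base model only solves prompt "a"; the toy trainer
    # returns a model that solves everything, mimicking improvement after SFT.
    base: Model = lambda p: "right" if p == "a" else "wrong"
    checker: Checker = lambda p, t: t == "right"
    toy_train: Trainer = lambda m, data: (lambda p: "right")
    final, kept = tpt(base, ["a", "b"], checker, toy_train, rounds=2, samples_per_prompt=2)
    print(kept)  # [2, 4]
```

Passing the trainer in as a callable keeps the loop itself independent of any particular fine-tuning library; only the toy stand-ins would change in practice.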

Abstract

Large language models (LLMs) have demonstrated strong capabilities in programming and mathematical reasoning tasks, but are constrained by limited high-quality training data. Synthetic data can be leveraged to enhance fine-tuning outcomes, but several factors influence this process, including model size, synthetic data volume, pruning strategy, and number of fine-tuning rounds. We explore these axes and investigate which conditions enable model self-improvement. We introduce the Think, Prune, Train process, a scalable framework that iteratively fine-tunes models on their own reasoning traces, using ground-truth pruning to ensure high-quality training data. This approach yields improved performance: on GSM8K, Gemma2-2B achieves a Pass@1 of 57.6% (from 41.9%), Gemma2-9B reaches 82%, matching LLaMA-3.1-70B, and LLaMA-3.1-70B attains 91%, even surpassing GPT-4o, demonstrating the effectiveness of self-generated reasoning and systematic data selection for improving LLM capabilities.
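For GSM8K, ground-truth pruning can be as simple as comparing each trace's final answer with the reference answer. The sketch below relies on the fact that GSM8K reference solutions end with a line of the form `#### <answer>`. It assumes the model's answer is the last number appearing in its trace; that extraction heuristic, and the function names `gold_answer`, `predicted_answer`, `is_correct`, and `prune_gsm8k`, are hypothetical and not necessarily the paper's exact rule.

```python
import re
from typing import List, Optional, Tuple

# Matches integers or decimals, optionally negative, with thousands separators.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def gold_answer(reference: str) -> str:
    """GSM8K reference solutions end with '#### <answer>'; return that answer."""
    return reference.split("####")[-1].strip().replace(",", "")


def predicted_answer(trace: str) -> Optional[str]:
    """Assumption: the last number in the generated trace is the model's answer."""
    matches = _NUMBER.findall(trace)
    if not matches:
        return None
    return matches[-1].replace(",", "")


def is_correct(trace: str, reference: str) -> bool:
    """Compare predicted and gold answers numerically (so '18' equals '18.0')."""
    pred = predicted_answer(trace)
    if pred is None:
        return False
    try:
        return float(pred) == float(gold_answer(reference))
    except ValueError:
        return False


def prune_gsm8k(samples: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
    """Keep (question, trace, reference) samples whose final answer is correct."""
    return [(q, t, r) for q, t, r in samples if is_correct(t, r)]
```

A trace such as `"She earns 9 * 2 = $18."` paired with a reference ending in `#### 18` is kept, while a trace ending in 17 is discarded.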

Paper Structure

This paper contains 21 sections, 4 equations, 3 figures, 9 tables, 1 algorithm.

Figures (3)

  • Figure 1: Recursive process for model training: data generation, pruning, and supervised fine-tuning (SFT), with the arrow indicating the feedback loop in which the fine-tuned model generates the data for the next round.
  • Figure 2: Recursive training enhances GSM8K Pass@1/20 performance in Gemma models. On GSM8K, the Gemma2-2B model's Pass@1 increases from 41.9% to 57.6% over four iterations of the TPT process (Algorithm 1), while the Gemma2-9B model improves to a Pass@1 of 82.4%. Bars M1-M4 represent the models trained through the TPT process, starting from the base Gemma2-2B/9B models.
  • Figure 3: Recursive Fine-Tuning Can Improve Over Open Source Models. The Gemma2-9B model improves from 66.0% to 82.4% Pass@1, while LLaMA-3.1-70B improves from 76.0% to 91.5%, surpassing GPT-4o (82%). This figure highlights the performance gains achieved through recursive fine-tuning.
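The figures report Pass@k scores. The paper's exact evaluation protocol is not stated here; one widely used way to compute Pass@k from n sampled solutions per problem, c of which are correct, is the unbiased estimator 1 - C(n-c, k)/C(n, k), averaged over problems. The sketch below implements that estimator; `pass_at_k` and `mean_pass_at_k` are hypothetical names.

```python
from math import comb
from typing import List, Tuple


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one problem: n samples drawn, c of them correct."""
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset contains at least one correct sample
    # Probability that a random size-k subset contains no correct sample.
    return 1.0 - comb(n - c, k) / comb(n, k)


def mean_pass_at_k(results: List[Tuple[int, int]], k: int) -> float:
    """Average pass@k over problems; `results` holds one (n, c) pair per problem."""
    return sum(pass_at_k(n, c, k) for n, c in results) / len(results)
```

For k = 1 the estimator reduces to c/n, so Pass@1 is the average fraction of correct samples per problem; for example, `pass_at_k(20, 5, 1)` gives 0.25.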