Rethinking Fine-Tuning when Scaling Test-Time Compute: Limiting Confidence Improves Mathematical Reasoning

Feng Chen, Allan Raventos, Nan Cheng, Surya Ganguli, Shaul Druckmann

TL;DR

The paper demonstrates that standard cross-entropy (CE) fine-tuning can be misaligned with pass@N test-time search: CE makes the model overconfident, so pass@N accuracy can decrease with longer training, and the harm grows as the test-time budget N is scaled up. It introduces Direct Coverage Optimization (DCO), an objective that directly maximizes the probability of finding the correct answer within N samples, and shows that DCO limits overconfidence and traces a Pareto-optimal frontier in the tradeoff between exploration and exploitation. Through experiments on MATH, MiniF2F, and LeanDojo theorem proving, DCO and its variants (DCOstep and DCOa) consistently improve pass@N performance, especially at large N, demonstrating the value of co-designing training-time objectives with test-time search strategies. The work argues for considering training and inference-time algorithms end to end to unlock scalable mathematical reasoning in LLMs.

Abstract

Recent progress in large language models (LLMs) highlights the power of scaling test-time compute to achieve strong performance on complex tasks, such as mathematical reasoning and code generation. This raises a critical question: how should model training be modified to optimize performance under a subsequent test-time compute strategy and budget? To explore this, we focus on pass@N, a simple test-time strategy that searches for a correct answer in $N$ independent samples. We show, surprisingly, that training with cross-entropy (CE) loss can be misaligned with pass@N, in that pass@N accuracy decreases with longer training. We explain the origins of this misalignment in terms of model overconfidence induced by CE, and experimentally verify our prediction that overconfidence impedes scaling test-time compute via pass@N. Furthermore, we propose a principled, modified training loss that is better aligned to pass@N: it limits model confidence and rescues pass@N test performance. Our algorithm demonstrates improved mathematical reasoning on the MATH and MiniF2F benchmarks in two scenarios: (1) providing answers to math questions; and (2) proving theorems by searching over proof trees of varying shapes. Overall, our work underscores the importance of co-designing two traditionally separate phases of LLM development: training-time protocols and test-time search and reasoning strategies.

Paper Structure

This paper contains 33 sections, 9 theorems, 28 equations, 15 figures, 6 tables.

Key Result

Lemma 4.1

$\forall N,\,N' > 0$, the pass@N coverage $\mathcal{C}^{N}$ is monotonic in $\mathcal{C}^{N'}$.
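A short derivation suggests why the lemma holds for a single problem $x$. It assumes (our reading, not stated above) that coverage is the probability that at least one of $N$ independent samples is the correct answer $y$, i.e. $\mathcal{C}^{N}(x) = 1-(1-\hat{p}(y|x))^{N}$. Eliminating $\hat{p}$ writes $\mathcal{C}^{N}$ as an increasing function of $\mathcal{C}^{N'}$:

```latex
\begin{aligned}
\mathcal{C}^{N'}(x) &= 1 - \bigl(1 - \hat{p}(y \mid x)\bigr)^{N'}
  \;\Longrightarrow\;
  1 - \hat{p}(y \mid x) = \bigl(1 - \mathcal{C}^{N'}(x)\bigr)^{1/N'}, \\
\mathcal{C}^{N}(x) &= 1 - \bigl(1 - \hat{p}(y \mid x)\bigr)^{N}
  = 1 - \bigl(1 - \mathcal{C}^{N'}(x)\bigr)^{N/N'}, \\
\frac{\partial \mathcal{C}^{N}}{\partial \mathcal{C}^{N'}}
  &= \frac{N}{N'}\,\bigl(1 - \mathcal{C}^{N'}(x)\bigr)^{N/N' - 1} > 0
  \quad \text{for } 0 \le \mathcal{C}^{N'}(x) < 1 .
\end{aligned}
```

Under this assumption, raising a problem's coverage at any one budget $N'$ raises it at every other budget $N$.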

Figures (15)

  • Figure 1: A model trained with CE loss becomes overconfident in its greedy completions, which harms its pass@N coverage; our proposed DCO objective limits this overconfidence. We fine-tune a Llama-3-8B base model on the MATH dataset to produce direct answers. $\hat{y}_\text{greedy}$ is the model's greedy completion when choosing the most likely token at each sampling step. (a) The model trained with CE loss assigns progressively larger confidences $\hat{p}(\hat{y}_\text{greedy}|x)$ to its greedy completions over the course of training. (b) At the end of training, only a small portion of the model's highly confident completions are correct. This will harm the model's pass@N performance when scaling up $N$. (c) Same as (a) but shown for the DCO loss with $N'=256$. The model trained on DCO shows a much milder overconfidence effect. (d) The confidence distribution of greedy completions after four epochs with DCO for various choices of $N'$. As $N'$ increases, the model's confidence in its greedy completion is more stringently limited, as a direct consequence of the overconfidence regularizer $F$.
  • Figure 2: (a) DCO improves on CE loss for pass@N test coverage over a broad range of N and traces a Pareto-optimal frontier. We fine-tune Llama-3-8B base models on MATH to produce direct answers: one with CE loss and others using $\mathcal{L}_\text{DCO}^{N'}$ for various $N'$ (color-coded). Each curve shows pass@N coverage for a single fine-tuned model. Note that no $N'$ is optimal for all $N$. The black curve is a Pareto-optimal performance frontier traced by the max of coverage curves for DCO over all $N'$. (Companion figures highlight $N'=256$ and report Llama-3-70B and AIME24 results.) (b) DCO limits overconfidence by attenuating gradients for examples on which the model is highly confident. We plot the confidence regularization factor $F$ from the DCO gradient equation. For CE loss ($N=1$), $F=1$ regardless of confidence $\hat{p}(y|x)$. For $N > 1$, $F$ decreases with $\hat{p}(y|x)$ and drops toward zero around $1/N$: once confidence on an example reaches $O(1/N)$, $F$ vanishes, preventing that confidence from increasing further. (c) Inverse confidence $N_\mathrm{eff}$ at training time controls the shape of search-tree exploration at test time. We provide schematic proof search trees for models trained with DCOstep for small (left) and large (right) $N_\mathrm{eff}$. Circles are proof states; edges are tactics; solid edges are explored, while dashed edges are not. Small $N_\mathrm{eff}$ yields a narrow search tree that may miss correct proofs (red, dashed), while larger $N_\mathrm{eff}$ expands the tree to include the correct proof. However, if the tree becomes too wide ($N_\mathrm{eff}^k\gg N$ for a $k$-step proof), sampling the correct path becomes unlikely and pass@N coverage decreases as $\mathcal{C}^N\approx N/N_\mathrm{eff}^k$. Thus $N_\mathrm{eff}$, chosen at training time, is a powerful knob for controlling the tradeoff between exploitation and exploration at test time.
  • Figure 3: (a, b) Overconfidence persists in CoT fine-tuning with CE loss. We fine-tune a Llama-3-8B base model on CoT traces in the MATH dataset and plot the distribution of the estimated model confidences $\hat{p}(y^\mathrm{mode}|x)$ over samples in the test set at various points in training. (a) For CE loss, the model becomes more confident in its most likely answers as training progresses: the confidence distribution shifts to the right. This effect is milder than in the direct-answer setting (Figure 1). (b) DCOa limits this shift of the confidence distribution over training. (c, d) Overconfidence is also present in GRPO fine-tuning and can result in a model that is overconfident and wrong. (c) The model trained with GRPO assigns progressively larger confidences $\hat{p}(y^\mathrm{mode}|x)$ to its most likely answers over the course of training. (d) At the end of training, only a small portion of the model's highly confident completions are correct, which will harm the model's pass@N performance when scaling up $N$.
  • Figure 4: Number of discarded samples as a function of training step for the DCOa experiments. The staircase shape reflects the model revisiting examples it has already seen in training: each step closely coincides with the start of a new epoch.
  • Figure 5: Easy data drives overconfidence and degrades performance when scaling test-time compute. For the experiments in Figure 1, we compute the directional derivative $\nabla_g\mathcal{L}_{\text{test}}^N$ at two points during the first training epoch: 11% (early) and 86% (late) of the way through. At each stage, the gradient direction $g=-\sum_{(x_i,y_i)\in\mathcal{B}}\nabla\ell_\text{CE}(x_i,y_i)$ is evaluated on batches of previously unseen training data from each difficulty level defined in the MATH dataset. Early in training, data from all difficulty levels contribute to decreasing the test loss (left). Later in training, however, easier examples (difficulty level 2) provide no further benefit, while the easiest examples (difficulty level 1) actively degrade test performance (right). The plotted $\nabla_g\mathcal{L}_{\text{test}}^N$ is an average over batches of unseen data, and we use $N=256$, corresponding to pass@256.
  • ...and 10 more figures
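The two quantities described in the Figure 2 caption can be made concrete with a small sketch. It assumes, as our reading rather than a quote from the paper, that the DCO loss for an example with correct answer $y$ is the negative log of pass@N coverage, $-\log(1-(1-\hat{p}(y|x))^{N})$. The chain rule then gives $F = N\hat{p}(1-\hat{p})^{N-1}/(1-(1-\hat{p})^{N})$, which equals 1 at $N=1$ and falls off once $\hat{p}$ reaches order $1/N$, matching panel (b). For panel (c) it further assumes that each of the $k$ proof steps puts probability $1/N_\mathrm{eff}$ on the correct tactic. The function names (`coverage`, `dco_loss`, `confidence_factor`, `path_coverage`) are hypothetical.

```python
import math


def coverage(p: float, n: int) -> float:
    """pass@n coverage for one problem: P(at least one of n i.i.d. samples is correct).

    Computes 1 - (1 - p)**n as -expm1(n * log1p(-p)) so that tiny p stays accurate.
    """
    if not 0.0 < p < 1.0:
        raise ValueError("p must lie strictly between 0 and 1")
    return -math.expm1(n * math.log1p(-p))


def dco_loss(p: float, n: int) -> float:
    """Assumed DCO loss: -log(pass@n coverage). For n == 1 this reduces to CE, -log p."""
    return -math.log(coverage(p, n))


def confidence_factor(p: float, n: int) -> float:
    """Assumed regularization factor F = n p (1 - p)**(n - 1) / (1 - (1 - p)**n).

    By the chain rule, grad(dco_loss) = F * grad(-log p), so F rescales the CE gradient.
    """
    return n * p * math.exp((n - 1) * math.log1p(-p)) / coverage(p, n)


def path_coverage(n_eff: float, k: int, n: int) -> float:
    """Exact pass@n coverage if each of k proof steps gives the correct tactic probability 1/n_eff."""
    return coverage(n_eff ** (-k), n)


if __name__ == "__main__":
    # F stays near 1 while p << 1/n and falls off steeply once p is of order 1/n or larger.
    for n in (1, 16, 256):
        row = [confidence_factor(p, n) for p in (1e-4, 1e-2, 0.1, 0.5)]
        print(f"N={n:>3}: " + "  ".join(f"{f:.3g}" for f in row))
    # Wide search tree (n_eff**k >> n): exact coverage vs. the N / n_eff**k approximation.
    n_eff, k, n = 4, 6, 64
    print(path_coverage(n_eff, k, n), n / n_eff**k)
```

Under these assumptions, the numbers reproduce the qualitative picture in the caption: CE ($N=1$) never attenuates its gradient, while larger $N$ stops pushing confidence on an example once it is already roughly $1/N$.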

Theorems & Definitions (16)

  • Lemma 4.1
  • Lemma 4.2: Upper bound on max confidence
  • Lemma 4.3: Lower bound on max confidence
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Lemma A.3
  • proof
  • Corollary A.4
  • ...and 6 more