
Uncertainty-Guided Checkpoint Selection for Reinforcement Finetuning of Large Language Models

Manh Nguyen, Dung Nguyen, Dai Do, Svetha Venkatesh, Hung Le

TL;DR

RL finetuning of large language models is notoriously unstable, making checkpoint selection critical yet costly when relying on validation data or exhaustive evaluation. The proposed Uncertainty-Guided Checkpoint Selection (UGCS) uses per-sample uncertainty to identify the hardest examples and scores checkpoints by their average reward on these cases within a short training window, reusing logs already produced during training. Across three datasets and multiple small-to-midsize LLMs, UGCS consistently finds checkpoints with stronger generalization, often outperforming validation-based and reward-only baselines, with notable gains on challenging tasks like AMC23. The method is simple, efficient, and broadly applicable, offering a practical alternative to traditional checkpoint-selection strategies while emphasizing learning on difficult cases to improve robustness.

Abstract

Reinforcement learning (RL) finetuning is crucial to aligning large language models (LLMs), but the process is notoriously unstable and exhibits high variance across model checkpoints. In practice, selecting the best checkpoint is challenging: evaluating checkpoints on a validation set during training is computationally expensive and requires a good validation set, while relying on the final checkpoint provides no guarantee of good performance. We introduce an uncertainty-guided approach for checkpoint selection (UGCS) that avoids these pitfalls. Our method identifies hard question-answer pairs using per-sample uncertainty and ranks checkpoints by how well they handle these challenging cases. By averaging the rewards of the top-uncertain samples over a short training window, our method produces a stable and discriminative signal without additional forward passes or significant computational overhead. Experiments across three datasets and three LLMs demonstrate that it consistently identifies checkpoints with stronger generalization, outperforming traditional strategies such as relying on training or validation performance. These results highlight that models solving their hardest tasks with low uncertainty are the most reliable overall.
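The scoring rule described above can be illustrated with a short Python sketch. It is a minimal sketch under our own assumptions, not the authors' implementation: we assume the training logs record, for every rollout at every step, a scalar reward and a scalar uncertainty score (for example, a mean token entropy); that a checkpoint saved at step t is scored from the samples logged in the last `window` steps ending at t; and that the "hardest" samples are the top fraction `p` by uncertainty, pooled over that window. The names `SampleLog`, `ugcs_score`, `select_checkpoint`, `window`, and `p` are hypothetical.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass
class SampleLog:
    """One logged rollout (hypothetical fields): its reward and uncertainty score."""
    reward: float
    uncertainty: float


def ugcs_score(logs: Dict[int, List[SampleLog]], step: int, window: int, p: float) -> float:
    """Mean reward of the top-p most uncertain samples logged in steps
    (step - window, step]. Assumption: samples are pooled across the window."""
    if not 0.0 < p <= 1.0:
        raise ValueError("p must be in (0, 1]")
    pooled: List[SampleLog] = []
    for s in range(step - window + 1, step + 1):
        pooled.extend(logs.get(s, []))  # reuse existing training logs only
    if not pooled:
        raise ValueError(f"no logged samples in window ending at step {step}")
    # Hardest samples = highest uncertainty; keep at least one.
    ranked = sorted(pooled, key=lambda x: x.uncertainty, reverse=True)
    k = max(1, round(p * len(ranked)))
    return sum(x.reward for x in ranked[:k]) / k


def select_checkpoint(logs: Dict[int, List[SampleLog]], checkpoint_steps: Sequence[int],
                      window: int, p: float) -> int:
    """Return the checkpoint step with the highest hypothetical UGCS score."""
    return max(checkpoint_steps, key=lambda t: ugcs_score(logs, t, window, p))


# Toy logs: step -> list of (reward, uncertainty) records.
logs = {
    1: [SampleLog(1.0, 0.2), SampleLog(0.0, 0.9)],
    2: [SampleLog(1.0, 0.8), SampleLog(1.0, 0.1)],
    3: [SampleLog(1.0, 0.7), SampleLog(0.0, 0.3)],
    4: [SampleLog(1.0, 0.95), SampleLog(1.0, 0.6)],
}
best = select_checkpoint(logs, checkpoint_steps=[2, 4], window=2, p=0.5)
```

In these toy logs, the plain mean reward over each two-step window is 0.75 for both checkpoints, whereas restricting to the most uncertain half separates them (0.5 for step 2, 1.0 for step 4), so step 4 is selected. Pooling across the window before ranking, rather than ranking within each step, is a design choice of this sketch; the paper may aggregate differently.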

Paper Structure

This paper contains 19 sections, 2 equations, 2 figures, 3 tables, 1 algorithm.

Figures (2)

  • Figure 1: Mean accuracy across datasets for different difficulty metrics, using Qwen2.5-0.5B trained on GSM8K.
  • Figure 2: Results of GSM8K training when varying $p$, for various models and datasets. For each model, AVG denotes the average accuracy over four datasets, while the other curves (with shaded ranges) show individual datasets with their standard deviations.