
Ignore the KL Penalty! Boosting Exploration on Critical Tokens to Enhance RL Fine-Tuning

Jean Vassoyan, Nathanaël Beau, Roman Plaud

TL;DR

The paper investigates exploration in RL fine-tuning of LLMs using a simple arithmetic addition task and reveals that a small set of critical tokens decisively influences outcomes. It introduces a token-aware prioritized KL penalty that weights the standard KL term by the pre-trained model's token-wise certainty, aiming to boost exploration on high-impact decisions. Empirical results show that more extensive pre-training improves generalization, while the prioritized KL penalty improves exploration efficiency and stabilizes performance on critical tokens across settings. The work provides a practical mechanism for making RL fine-tuning more efficient and highlights the role of token-level uncertainty in guiding exploration; its main limitations are the small model size and the narrow task domain.
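The idea of weighting the KL term by the pre-trained model's certainty can be illustrated with a minimal NumPy sketch. It rests on assumptions that this summary does not confirm. Certainty is taken to be one minus the normalized entropy of the pre-trained model's next-token distribution, and the weight is that certainty raised to a hypothetical exponent `alpha`. The function names and the `beta`/`alpha` parameters are illustrative only, not the authors' implementation.

```python
import numpy as np


def token_certainty(ref_probs: np.ndarray) -> np.ndarray:
    """Certainty of the pre-trained (reference) model at each position, in [0, 1].

    Hypothetical definition: 1 - H(p) / log(V), so a one-hot distribution
    scores 1 and a uniform distribution scores 0.
    """
    safe = np.clip(ref_probs, 1e-12, 1.0)  # avoid log(0)
    entropy = -(ref_probs * np.log(safe)).sum(axis=-1)
    certainty = 1.0 - entropy / np.log(ref_probs.shape[-1])
    return np.clip(certainty, 0.0, 1.0)  # guard against tiny rounding errors


def prioritized_kl_penalty(policy_logp, ref_logp, ref_probs, beta=0.1, alpha=1.0):
    """Per-token KL penalty scaled by the reference model's certainty (assumed form).

    policy_logp, ref_logp: log-probabilities of the sampled tokens, shape (T,).
    ref_probs: reference next-token distributions, shape (T, V).
    """
    log_ratio = np.asarray(policy_logp) - np.asarray(ref_logp)  # single-sample KL estimate
    weights = token_certainty(np.asarray(ref_probs)) ** alpha
    return beta * weights * log_ratio


# Position 0: the pre-trained model is confident; position 1: it is uncertain.
ref_probs = np.array([[0.97, 0.01, 0.01, 0.01],
                      [0.25, 0.25, 0.25, 0.25]])
# Same log-ratio at both positions, to isolate the effect of the weight.
penalty = prioritized_kl_penalty(policy_logp=[-0.2, -0.2],
                                 ref_logp=[-1.0, -1.0],
                                 ref_probs=ref_probs)
print(penalty)  # the uncertain position receives (almost) no penalty
```

Under this weighting, deviations where the pre-trained model was already unsure cost almost nothing, so exploration is concentrated on those positions. Deviations from confident predictions are still penalized.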

Abstract

The ability to achieve long-term goals is a key challenge in the current development of large language models (LLMs). To address this, pre-trained LLMs can be fine-tuned with reinforcement learning (RL) to explore solutions that optimize a given goal. However, exploration with LLMs is difficult, as a balance has to be struck between discovering new solutions and staying close enough to the pre-trained model, so as not to degrade basic capabilities. This is typically controlled with a Kullback-Leibler (KL) penalty. In this paper, we investigate the exploration dynamics of a small language model on a simple arithmetic task. We show how varying degrees of pre-training influence exploration and demonstrate the importance of "critical tokens" which have a dramatic impact on the final outcome. Consequently, we introduce a simple modification to the KL penalty that favors exploration on critical tokens, increasing the efficiency of the RL fine-tuning stage.
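In PPO-style RL fine-tuning, the KL penalty mentioned above is often folded into the reward as a per-token term. The sketch below shows one common formulation. It assumes a sparse task reward granted only at the final token, such as 1.0 for a correct sum. `kl_shaped_rewards` and `beta` are illustrative names, not code from the paper.

```python
import numpy as np


def kl_shaped_rewards(policy_logp, ref_logp, task_reward, beta=0.1):
    """Per-token rewards for one sampled completion.

    Every token pays -beta * (log pi(a_t) - log pi_ref(a_t)), a single-sample
    estimate of the KL to the pre-trained model; the task reward is added
    at the last token only.
    """
    log_ratio = np.asarray(policy_logp, dtype=float) - np.asarray(ref_logp, dtype=float)
    rewards = -beta * log_ratio
    rewards[-1] += task_reward  # sparse outcome reward
    return rewards


rewards = kl_shaped_rewards(policy_logp=[-0.1, -0.5, -0.2],
                            ref_logp=[-0.1, -1.5, -0.2],
                            task_reward=1.0)
print(rewards)
```

A larger `beta` keeps the policy closer to the pre-trained model, but it penalizes every deviation equally, including deviations on critical tokens where exploration matters most.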

Paper Structure

This paper contains 23 sections, 2 equations, 7 figures, 5 tables.

Figures (7)

  • Figure 1: Illustration of the addition task with scratchpad, for a model pre-trained on numbers up to 3 digits. The highlighted critical tokens are decision points where the model tends to make mistakes, mainly because it is tempted to process the number as if it were shorter. This occurs when the model is faced with a number that is longer than those encountered during the pre-training stage (here, 4 digits instead of 3).
  • Figure 2: Model accuracy on addition tasks for models trained on numbers up to digit lengths $N=7, 9, 11, 13$, evaluated across a range of digit lengths. Error bars indicate 95% confidence intervals. Full detailed results are provided in the Appendix.
  • Figure 3: Learning curves of multiple models pre-trained on numbers up to $N$ digits and fine-tuned with RL on $N+2$ digits.
  • Figure 4: Top: Learning curves of a model fine-tuned with RL on $N+1=8$ digits. Bottom: Probability of making the right prediction on two critical tokens. Results on more critical tokens are provided in the Appendix.
  • Figure 5: Output examples for addition tasks on $N+1$ digit lengths (the model is faced with numbers one digit longer than those encountered in pre-training). Each generated token is colored according to its certainty: green indicates maximal certainty and red minimal certainty.
  • ...and 2 more figures
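The length-generalization setup in these captions can be sketched as follows: pre-training on numbers up to $N$ digits, then RL fine-tuning on $N+1$ or $N+2$ digits. This is a hypothetical illustration. It assumes that both operands share the fine-tuning length, that the model's final answer has already been extracted from its scratchpad as a string, and that the reward is a binary exact match. `sample_operands`, `make_rl_prompts` and `addition_reward` are illustrative names.

```python
import random


def sample_operands(num_digits, rng):
    """Draw two integers that each have exactly `num_digits` digits."""
    low = 10 ** (num_digits - 1) if num_digits > 1 else 0
    high = 10 ** num_digits - 1
    return rng.randint(low, high), rng.randint(low, high)


def make_rl_prompts(max_pretrain_digits, extra_digits, n, seed=0):
    """Operand pairs for RL fine-tuning, `extra_digits` longer than pre-training."""
    rng = random.Random(seed)
    length = max_pretrain_digits + extra_digits
    return [sample_operands(length, rng) for _ in range(n)]


def addition_reward(a, b, predicted_answer):
    """Binary outcome reward: 1.0 for an exact match with a + b, otherwise 0.0."""
    try:
        return 1.0 if int(predicted_answer.strip()) == a + b else 0.0
    except ValueError:  # malformed output earns no reward
        return 0.0


# Pre-trained up to N = 7 digits, fine-tuned on N + 1 = 8 digits (as in Figure 4).
pairs = make_rl_prompts(max_pretrain_digits=7, extra_digits=1, n=3)
a, b = pairs[0]
print(a, b, addition_reward(a, b, str(a + b)))
```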