Optimizing Language Models for Inference Time Objectives using Reinforcement Learning

Yunhao Tang, Kunhao Zheng, Gabriel Synnaeve, Rémi Munos

TL;DR

The paper tackles the problem of aligning language-model training with downstream inference-time strategies by explicitly optimizing for inference-time objectives such as pass@$k$ and majority voting via an online reinforcement-learning framework. It derives unbiased and biased gradient estimators using a leave-one-out variance-reduction approach that couples $k$ samples through the aggregation function $f$, enabling practical credit assignment for inference-time behavior. Through experiments on mathematical reasoning and code generation, the authors show that $k$-sample objective variants can yield significant gains in inference-time performance, with trade-offs in variance, KL-divergence, and generalization across model sizes and datasets. The work suggests practical paths to deploy inference-time aware training, including combinations with PPO and insights on how results scale with $k$ and model capacity, impacting real-world deployments where latency and accuracy under collaborative inference-time strategies matter.

Abstract

In this work, we investigate the merits of explicitly optimizing for inference time algorithmic performance during model training. We show how optimizing for inference time performance can improve overall model efficacy. We consider generic inference time objectives with $k$ samples, with a focus on pass@$k$ and majority voting as two main applications. With language model training on reasoning datasets, we showcase the performance trade-off enabled by training with such objectives. When training on code generation tasks, we show that the approach significantly improves pass@$k$ objectives compared to the baseline method.
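As a rough illustration of the two $k$-sample objectives named above, the sketch below writes each as an aggregation function over $k$ samples. The function names, the binary correctness rewards, and the tie-breaking rule are assumptions made for this example and are not taken from the paper.

```python
from collections import Counter
from typing import Hashable, Sequence


def pass_at_k(rewards: Sequence[float]) -> float:
    """pass@k aggregation: the best reward among the k samples."""
    return max(rewards)


def majority_vote(answers: Sequence[Hashable], reference: Hashable) -> float:
    """Majority-voting aggregation: 1.0 if the most frequent answer equals the reference.

    Ties go to the answer that appears first (an assumption of this sketch).
    """
    most_common_answer, _ = Counter(answers).most_common(1)[0]
    return 1.0 if most_common_answer == reference else 0.0


if __name__ == "__main__":
    # Four samples for one prompt, only one of them correct.
    print(pass_at_k([0.0, 0.0, 1.0, 0.0]))
    # Four final answers, two of which agree with the reference "42".
    print(majority_vote(["42", "41", "42", "7"], "42"))
```

With binary rewards, the mean objective would score the first example as 0.25, while pass@$k$ scores it as a success because at least one sample is correct.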

Paper Structure

This paper contains 53 sections, 2 theorems, 24 equations, 16 figures, 3 tables, 1 algorithm.

Key Result

Lemma 1

(Unbiased leave-one-out gradient estimate) The gradient estimate with the leave-one-out control variate (the paper's equation labeled eq:loo) is unbiased.
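The lemma can be checked numerically on a tiny softmax bandit. The sketch below assumes the estimate has the form $\sum_i \big(f(y_1,\dots,y_k) - f(y_{-i})\big)\nabla\log\pi(y_i)$ with $f=\max$ (the pass@$k$ case) and with $f(y_{-i})$ applied to the other $k-1$ samples; this form is an assumption about the equation, not a quotation of it. Because $f(y_{-i})$ does not depend on $y_i$, its product with the score of $y_i$ has zero mean, which is why the estimate stays unbiased. The code enumerates all $k$-tuples to compute the exact expectation of the estimate and compares it with a finite-difference gradient of the exact objective. All function names are hypothetical.

```python
import itertools
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def objective(logits, rewards, k):
    """Exact E[max_i r(y_i)], enumerating all k-tuples of actions."""
    probs = softmax(logits)
    value = 0.0
    for ys in itertools.product(range(len(logits)), repeat=k):
        p = math.prod(probs[y] for y in ys)
        value += p * max(rewards[y] for y in ys)
    return value


def expected_loo_gradient(logits, rewards, k):
    """Exact expectation of the leave-one-out estimate (requires k >= 2)."""
    probs = softmax(logits)
    n = len(logits)
    grad = [0.0] * n
    for ys in itertools.product(range(n), repeat=k):
        p = math.prod(probs[y] for y in ys)
        f_all = max(rewards[y] for y in ys)
        for i, y_i in enumerate(ys):
            # Control variate: f on the other k-1 samples, independent of y_i.
            f_rest = max(rewards[y] for j, y in enumerate(ys) if j != i)
            advantage = f_all - f_rest
            # Score of a softmax policy w.r.t. its logits: one_hot(y_i) - probs.
            for a in range(n):
                score = (1.0 if a == y_i else 0.0) - probs[a]
                grad[a] += p * advantage * score
    return grad


def finite_difference_gradient(logits, rewards, k, eps=1e-5):
    """Central finite differences of the exact objective."""
    grad = []
    for a in range(len(logits)):
        up, down = list(logits), list(logits)
        up[a] += eps
        down[a] -= eps
        grad.append((objective(up, rewards, k) - objective(down, rewards, k)) / (2 * eps))
    return grad


if __name__ == "__main__":
    logits, rewards = [0.2, -0.5, 1.0], [0.3, -1.2, 0.8]
    print(expected_loo_gradient(logits, rewards, k=3))
    print(finite_difference_gradient(logits, rewards, k=3))
```

The two printed vectors should agree up to finite-difference error, which is consistent with the lemma for this toy case but is not a substitute for the proof.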

Figures (16)

  • Figure 1: Comparison of different gradient estimates in a bandit case. We set up a bandit problem with $|\mathcal{Y}|=100$ possible actions, where each reward $r(\mathbf{y})$ is a deterministic scalar sampled from a unit Gaussian. We compare three algorithmic variants: the mean policy gradient, the pass@$k$ policy gradient and its biased variant. All algorithms apply $k=4$ samples per update with learning rate $\eta=1.0$. Overall, we see that the mean policy gradient baseline makes the fastest improvement on the mean performance when graphed against the learning steps (left plot); however, it is generally less KL-efficient than the $k$-sample alternatives (middle plot). When measuring the pass@$k$ performance, the $k$-sample gradient estimates lead to significantly faster improvements (right plot).
  • Figure 2: MATH training pass@$k$ 8B model. We compare three algorithms: the regular mean policy gradient baseline and two variants of the pass@$k$ policy gradient algorithm (unbiased and biased). We split the performance across MATH difficulty levels and report the mean performance and pass@$4$ performance over time. We observe that as training progresses, the pass@$k$ policy gradient algorithms seem to display a slight advantage over the baseline algorithm.
  • Figure 3: Ablation with number of samples $k$. We vary the number of samples $k$ for each gradient update for the pass@$k$ objective. We observe greater efficiency gains for the pass@$k$ gradient estimates compared to the policy gradient baseline. Importantly, note that as $k$ varies, the pass@$k$ algorithm changes its objective.
  • Figure 4: HARP training pass@$k$ 70B model. We observe that the regular policy gradient estimate improves over the pass@$k$ variants for the mean performance metric, while under-performing on the pass@$k$ objective. Such a trade-off is less significant for the MATH dataset, where we speculate that the 70B model is too powerful and learning signals are too sparse to make a difference.
  • Figure 5: Code generation task evaluation performance. We observe a clear trade-off between pass@1 and pass@8 on CodeContests and TACO using Llama 3.1 70B: the mean policy gradient achieves the best mean (pass@1) performance, while the pass@$k$ gradient variants clearly achieve much better performance on the pass@$k$ metric.
  • ...and 11 more figures
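The bandit setup described in the Figure 1 caption can be sketched in a few lines. This is a minimal reconstruction under stated assumptions, not the authors' code: the policy is assumed to be a softmax over logits initialized at zero, the mean policy gradient is assumed to use a leave-one-out mean baseline, the pass@$k$ update is assumed to use the advantage $\max(\text{all}) - \max(\text{others})$, and the biased variant is omitted because its form is not given here. All function names, the step count, and the seed are hypothetical.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def exact_mean(probs, rewards):
    return sum(p * r for p, r in zip(probs, rewards))


def exact_pass_at_k(probs, rewards, k):
    """E[max of k i.i.d. draws], computed from the CDF of the maximum."""
    order = sorted(range(len(rewards)), key=lambda a: rewards[a])
    value, cdf_prev = 0.0, 0.0
    for a in order:
        cdf = cdf_prev + probs[a]
        value += rewards[a] * (cdf ** k - cdf_prev ** k)
        cdf_prev = cdf
    return value


def advantages(sample_rewards, objective):
    """Per-sample advantages for the two objectives (assumed forms, k >= 2)."""
    k = len(sample_rewards)
    if objective == "mean":
        total = sum(sample_rewards)
        # Reward minus the mean of the other k-1 rewards, averaged over k samples.
        return [(r - (total - r) / (k - 1)) / k for r in sample_rewards]
    if objective == "pass@k":
        f_all = max(sample_rewards)
        # Best of all k samples minus best of the other k-1 samples.
        return [f_all - max(sample_rewards[:i] + sample_rewards[i + 1:]) for i in range(k)]
    raise ValueError(f"unknown objective: {objective}")


def train(objective, rewards, k=4, lr=1.0, steps=200, seed=0):
    """Run stochastic gradient ascent on the logits and return the final policy."""
    rng = random.Random(seed)
    n = len(rewards)
    logits = [0.0] * n
    for _ in range(steps):
        probs = softmax(logits)
        ys = rng.choices(range(n), weights=probs, k=k)
        advs = advantages([rewards[y] for y in ys], objective)
        # Gradient of sum_i adv_i * log pi(y_i) w.r.t. logits:
        # sum_i adv_i * (one_hot(y_i) - probs).
        total_adv = sum(advs)
        for a in range(n):
            logits[a] -= lr * total_adv * probs[a]
        for y, adv in zip(ys, advs):
            logits[y] += lr * adv
    return softmax(logits)


if __name__ == "__main__":
    reward_rng = random.Random(0)
    rewards = [reward_rng.gauss(0.0, 1.0) for _ in range(100)]  # |Y| = 100
    for objective in ("mean", "pass@k"):
        probs = train(objective, rewards)
        print(objective, exact_mean(probs, rewards), exact_pass_at_k(probs, rewards, 4))
```

The script prints the exact mean reward and exact pass@4 of each trained policy, which correspond to the quantities on the left and right plots of Figure 1; the KL measurement from the middle plot is not reproduced in this sketch.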

Theorems & Definitions (4)

  • Lemma 1
  • proof
  • Lemma 2
  • proof