Optimizing Language Models for Inference Time Objectives using Reinforcement Learning
Yunhao Tang, Kunhao Zheng, Gabriel Synnaeve, Rémi Munos
TL;DR
The paper addresses the mismatch between how language models are trained and how they are used at inference time by directly optimizing inference-time objectives, such as pass@$k$ and majority voting, within an online reinforcement-learning framework. It derives unbiased and biased gradient estimators for $k$-sample objectives, using leave-one-out variance reduction to assign credit to each of the $k$ samples that the aggregation function $f$ combines. In experiments on mathematical reasoning and code generation, training with $k$-sample objectives improves inference-time performance, with trade-offs in variance, KL divergence, and generalization across model sizes and datasets. The work also covers practical aspects of inference-time-aware training, including combination with PPO and how results scale with $k$ and model capacity, which is relevant to deployments that rely on multi-sample inference strategies.
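To make the leave-one-out idea concrete, the sketch below shows one common way to build per-sample advantages for a $k$-sample objective: score all $k$ samples with $f$, then subtract $f$ evaluated on the other $k-1$ samples. Because that baseline does not depend on sample $i$, subtracting it leaves a score-function gradient unbiased. This is an illustrative assumption about the general technique and may differ in detail from the paper's estimators; the function name `leave_one_out_advantages` and the use of `max` as a stand-in for pass@$k$ over binary rewards are hypothetical.

```python
def leave_one_out_advantages(rewards, f):
    """Per-sample advantages for a k-sample objective f (illustrative sketch).

    rewards: list of k per-sample rewards (k >= 2 so the held-out set is non-empty).
    f: aggregation function mapping a list of rewards to a scalar.
    """
    total = f(rewards)  # objective on all k samples
    advantages = []
    for i in range(len(rewards)):
        # Baseline: the same objective computed without sample i.
        rest = rewards[:i] + rewards[i + 1:]
        advantages.append(total - f(rest))
    return advantages


# With f = max (pass@k on 0/1 rewards), only a sample whose removal
# would change the outcome receives non-zero credit.
single_success = leave_one_out_advantages([0.0, 1.0, 0.0], max)
two_successes = leave_one_out_advantages([1.0, 1.0, 0.0], max)
```

In this sketch, a lone correct sample gets the full credit, while two correct samples each get zero because removing either one leaves the pass@$k$ outcome unchanged.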
Abstract
In this work, we investigate the merits of explicitly optimizing for inference-time algorithmic performance during model training. We show how optimizing for inference-time performance can improve overall model efficacy. We consider generic inference-time objectives with $k$ samples, with a focus on pass@$k$ and majority voting as two main applications. With language model training on reasoning datasets, we showcase the performance trade-off enabled by training with such objectives. When training on code generation tasks, we show that the approach significantly improves pass@$k$ objectives compared to the baseline method.
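As a rough illustration of the two $k$-sample objectives named above, the following sketch scores a single prompt given $k$ sampled outputs. It assumes binary correctness for pass@$k$ and exact-match answers for majority voting; the function names and the tie-breaking rule (first answer encountered wins) are assumptions made for this example, not details taken from the paper.

```python
from collections import Counter


def pass_at_k(correct):
    """Return 1.0 if at least one of the k samples is correct, else 0.0."""
    return float(any(correct))


def majority_vote(answers, reference):
    """Return 1.0 if the most frequent sampled answer matches the reference.

    Ties are broken in favor of the answer seen first (an assumption here;
    Counter.most_common keeps first-encountered order among equal counts).
    """
    top_answer, _ = Counter(answers).most_common(1)[0]
    return float(top_answer == reference)


# One correct sample out of three is enough for pass@3.
p = pass_at_k([False, True, False])
# Majority voting needs the correct answer to be the most common one.
m_hit = majority_vote(["4", "5", "4"], reference="4")
m_miss = majority_vote(["4", "5", "5"], reference="4")
```

The contrast between the two is that pass@$k$ rewards any single success among the samples, whereas majority voting rewards agreement on the correct answer.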
