LASeR: Learning to Adaptively Select Reward Models with Multi-Armed Bandits
Duy Nguyen, Archiki Prasad, Elias Stengel-Eskin, Mohit Bansal
TL;DR
LASeR tackles the problem of aligning LLMs with multiple reward models (RMs) by adaptively selecting a single RM per training instance using a contextual multi-armed bandit (LinUCB). This per-instance choice sidesteps both the limited generalization of any single fixed RM and the conflicting signals of RM ensembles, while keeping training efficient. Across reasoning, instruction-following, and long-context tasks, LASeR yields consistent accuracy gains and substantial speedups over baselines, and it adapts its RM usage to the task and data category. The approach is robust to noisy RMs, supports adding new RMs online, and offers practical benefits for scalable, multi-objective LLM alignment.
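The per-instance selection can be pictured with standard disjoint-arm LinUCB, where each arm is one RM. This is a minimal sketch, not the paper's implementation. It assumes each training instance is summarized by a fixed-length context vector `x`, and that a scalar bandit reward is observed after the LLM is trained with the chosen RM. The class name `LinUCBSelector`, the `alpha` value, and how contexts and rewards are produced are all hypothetical. `add_arm` shows how a new RM could join online with a fresh, uninformed model.

```python
import numpy as np


class LinUCBSelector:
    """Hypothetical disjoint LinUCB selector: one linear model per reward model (arm)."""

    def __init__(self, n_arms: int, dim: int, alpha: float = 1.0):
        self.alpha = alpha  # exploration strength (assumed hyperparameter)
        # Per-arm ridge regression statistics: A_a = I + sum x x^T, b_a = sum r x
        self.A = [np.eye(dim) for _ in range(n_arms)]
        self.b = [np.zeros(dim) for _ in range(n_arms)]

    def select(self, x: np.ndarray) -> int:
        """Return the arm (RM index) with the highest upper confidence bound for context x."""
        scores = []
        for A, b in zip(self.A, self.b):
            A_inv = np.linalg.inv(A)
            theta = A_inv @ b  # estimated reward weights for this arm
            bonus = self.alpha * np.sqrt(x @ A_inv @ x)  # uncertainty bonus
            scores.append(theta @ x + bonus)
        return int(np.argmax(scores))

    def update(self, arm: int, x: np.ndarray, reward: float) -> None:
        """Fold the observed bandit reward for (arm, x) into that arm's statistics."""
        self.A[arm] += np.outer(x, x)
        self.b[arm] += reward * x

    def add_arm(self) -> int:
        """Register a new RM online; it starts with identity/zero statistics."""
        dim = self.A[0].shape[0]
        self.A.append(np.eye(dim))
        self.b.append(np.zeros(dim))
        return len(self.A) - 1
```

In a training loop, `select` would pick the RM for each instance and `update` would run once the resulting bandit reward is known. The exploration bonus shrinks as an arm accumulates data, so under-tried RMs (including newly added ones) still get sampled.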
Abstract
Reward Models (RMs) are crucial to aligning large language models (LLMs), but the degree to which an RM specialized to one task (e.g., writing) generalizes to new tasks (e.g., math) is often not known a priori, so training LLMs with a single fixed RM is often suboptimal. However, optimizing LLMs with multiple RMs simultaneously can incur a prohibitively high computational cost and lead to conflicting signals from different RMs that may degrade performance. To address these challenges, we introduce LASeR (Learning to Adaptively Select Rewards), which frames reward model selection as a multi-armed bandit problem and efficiently and iteratively trains LLMs with multiple RMs by selecting the best-suited RM for each instance. On commonsense and math reasoning tasks, we show that LASeR boosts iterative LLM training, improving the absolute average accuracy of Llama-3-8B over three datasets by 2.67% over an ensemble of RM scores while also showing superior efficiency (e.g., a 2x speedup). Moreover, on WildChat (open-ended instruction-following tasks), LASeR achieves a 72.69% AlpacaEval win rate over the RM score ensemble baseline. Extending to long-context generation, LASeR improves by 2.96 F1 points (avg.) on single-document QA tasks and 2.97 F1 points on few-shot learning over the RM score ensemble baseline with best-of-n sampling.
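The contrast between the RM score ensemble baseline and per-instance RM selection under best-of-n sampling can be sketched as follows. This sketch assumes each RM is a plain callable mapping `(prompt, response)` to a scalar score, and that the arm index for the instance comes from some bandit selector. The function names and the toy RMs in the usage example are hypothetical. The ensemble calls every RM on every candidate. The selection variant calls only one RM, which is where the efficiency gain and the avoidance of conflicting signals come from.

```python
from typing import Callable, Sequence

# Hypothetical RM interface: (prompt, response) -> scalar score
RewardModel = Callable[[str, str], float]


def best_of_n_ensemble(prompt: str, candidates: Sequence[str],
                       rms: Sequence[RewardModel]) -> str:
    """Baseline: average all RM scores per candidate (len(rms) * n RM calls)."""
    def avg_score(resp: str) -> float:
        return sum(rm(prompt, resp) for rm in rms) / len(rms)
    return max(candidates, key=avg_score)


def best_of_n_selected(prompt: str, candidates: Sequence[str],
                       rms: Sequence[RewardModel], arm: int) -> str:
    """Per-instance selection: score candidates with only the chosen RM (n RM calls)."""
    rm = rms[arm]
    return max(candidates, key=lambda resp: rm(prompt, resp))


# Toy usage: a length-based RM conflicts with a "contains a number" RM on a math prompt.
rms = [lambda p, r: float(len(r)), lambda p, r: float(any(c.isdigit() for c in r))]
candidates = ["42", "a much longer answer"]
ensemble_pick = best_of_n_ensemble("What is 6 * 7?", candidates, rms)
selected_pick = best_of_n_selected("What is 6 * 7?", candidates, rms, arm=1)
```

In this toy setup, the averaged score is dominated by the length RM, so the ensemble prefers the verbose answer. Selecting the number-aware RM for this instance picks "42" instead.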
