LASeR: Learning to Adaptively Select Reward Models with Multi-Armed Bandits

Duy Nguyen, Archiki Prasad, Elias Stengel-Eskin, Mohit Bansal

TL;DR

LASeR tackles the problem of aligning LLMs with multiple reward models by adaptively selecting a single RM per training instance using a contextual multi-armed bandit (LinUCB). This per-instance RM choice mitigates generalization gaps and conflicting signals inherent in RM ensembles, while preserving training efficiency. Empirical results across reasoning, instruction-following, and long-context tasks show consistent accuracy gains and substantial speedups over baselines, with LASeR effectively adapting RM usage to task and data category. The approach is robust to noisy RMs, supports adding new RMs online, and offers practical benefits for scalable, multi-objective LLM alignment.
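
As a rough illustration of the per-instance selection step, the following is a minimal sketch of disjoint LinUCB with one arm per reward model. It is not the authors' implementation: the class name `LinUCB`, the exploration weight `alpha`, and the use of a fixed-size query feature vector `x` as the context are all assumptions made for this example, and the scalar `reward` passed to `update` stands in for whatever feedback signal the bandit receives (the Figure 1 caption indicates it is derived from the training loss).

```python
import numpy as np


class LinUCB:
    """Disjoint LinUCB: one ridge-regression model per arm (here, per reward model)."""

    def __init__(self, n_arms: int, dim: int, alpha: float = 1.0):
        self.alpha = alpha  # exploration weight (assumed hyperparameter)
        # Per-arm statistics: A accumulates x x^T (plus identity), b accumulates reward * x.
        self.A = [np.eye(dim) for _ in range(n_arms)]
        self.b = [np.zeros(dim) for _ in range(n_arms)]

    def select(self, x: np.ndarray) -> int:
        """Return the index of the arm with the highest upper confidence bound for context x."""
        scores = []
        for A, b in zip(self.A, self.b):
            A_inv = np.linalg.inv(A)
            theta = A_inv @ b                               # ridge estimate of the arm's payoff weights
            bonus = self.alpha * np.sqrt(x @ A_inv @ x)     # uncertainty bonus for this context
            scores.append(theta @ x + bonus)
        return int(np.argmax(scores))

    def update(self, arm: int, x: np.ndarray, reward: float) -> None:
        """Fold one observed (context, reward) pair into the chosen arm's statistics."""
        self.A[arm] += np.outer(x, x)
        self.b[arm] += reward * x
```

In use, one would call `select(x)` on a query's feature vector to pick a reward model, train on the preference pairs that model produces, and then call `update(arm, x, reward)` with the resulting feedback. The uncertainty bonus shrinks for arms that have been tried often on similar contexts, which is what lets the bandit balance trying under-used reward models against exploiting the one that has worked best so far.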

Abstract

Reward Models (RMs) are crucial to aligning large language models (LLMs), but the degree to which an RM specialized to one task (e.g., writing) generalizes to new tasks (e.g., math) is often not known a priori, so training LLMs with a single fixed RM is frequently suboptimal. However, optimizing LLMs with multiple RMs simultaneously can incur a prohibitively high computational cost and lead to conflicting signals from different RMs that may degrade performance. To address these challenges, we introduce LASeR (Learning to Adaptively Select Rewards), which frames reward model selection as a multi-armed bandit problem, efficiently and iteratively training LLMs using multiple RMs by selecting the best-suited RM for each instance. On commonsense and math reasoning tasks, we show that LASeR boosts iterative LLM training, improving the average accuracy of Llama-3-8B across three datasets by 2.67% (absolute) over an ensemble of RM scores while also showing superior efficiency (e.g., a 2x speedup). Moreover, on WildChat (open-ended instruction-following tasks), LASeR leads to a 72.69% AlpacaEval win rate over the RM score ensemble baseline. Extending to long-context generation, LASeR improves by 2.96 F1 points (avg.) on single-document QA tasks and 2.97 F1 points on few-shot learning over the RM score ensemble baseline with best-of-n sampling.

Paper Structure

This paper contains 22 sections, 8 equations, 13 figures, 21 tables, and 1 algorithm.

Figures (13)

  • Figure 1: Overview of LASeR. Given the query, the multi-armed bandit selects an RM depending on the underlying query and the bandit's parameters (based on the usage of each RM and the expected MAB reward). At iteration $m$, the LLM generates multiple responses that are scored based on the selected RM for that query. These responses are ranked into preference pairs, which are then used to fine-tune the model. The same training loss $\mathcal{L}^m$ is used to update the parameters of the LLM as well as the MAB for the next iteration, making the entire pipeline iterative.
  • Figure 2: Length-controlled AlpacaEval win rates comparing LASeR against baselines on WildChat instruction-following tasks using Llama-3-8B. The top row shows comparisons against single RM selection methods, while the bottom row shows comparisons against multi-RM ensemble methods.
  • Figure 3: Agreement in preference rankings between RMs on MMLU (left) and WildChat (right).
  • Figure 4: Training efficiency of LASeR vs. different baselines on GSM8K.
  • Figure 5: Reward difference distribution of chosen and rejected responses of Llama-3-8B on RewardBench.
  • ...and 8 more figures
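
The scoring-and-ranking step described in the Figure 1 caption (responses scored by the selected RM, then ranked into preference pairs) might look roughly like the sketch below. The function name `build_preference_pairs`, the `min_margin` parameter, and the all-pairs strategy are assumptions for illustration; the text above does not specify how pairs are formed from the ranking.

```python
from itertools import combinations


def build_preference_pairs(responses, scores, min_margin=0.0):
    """Turn RM-scored responses into (chosen, rejected) pairs.

    Assumption: every pair of responses whose score gap exceeds `min_margin`
    becomes one preference pair; ties are skipped.
    """
    # Rank responses from highest to lowest score under the selected RM.
    ranked = sorted(zip(responses, scores), key=lambda rs: rs[1], reverse=True)
    pairs = []
    # combinations() preserves order, so the first element always ranks at least as high.
    for (hi_resp, hi_score), (lo_resp, lo_score) in combinations(ranked, 2):
        if hi_score - lo_score > min_margin:
            pairs.append((hi_resp, lo_resp))
    return pairs
```

The resulting pairs are the kind of input a preference-based fine-tuning loss would consume. Skipping tied responses avoids training on pairs where the selected RM expresses no preference.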