Generalizing Reward Modeling for Out-of-Distribution Preference Learning

Chen Jia

TL;DR

This work tackles reward-model generalization for out-of-distribution preference learning (OOD PL) in LLM alignment. It introduces a gradient-based bilevel meta-learning framework in which a single reward function is trained to guide policy optimization across multiple task distributions, with KL regularization mitigating policy drift and distribution shift. The outer objective optimizes preference alignment, while the inner objective performs task-specific policy fine-tuning. A convergence bound shows that the method approaches a stationary point as the number of outer iterations grows and the reward controlling factor $\beta$ increases. Empirically, the approach outperforms a variety of strong baselines on controlled sentiment generation and knowledge answer generation across 20 held-out domains, with improvements in PL accuracy, reward-model (RM)-based rewards, and human-judgment metrics (including GPT-4 judgments). The results indicate the practical value of a meta-trained, general RM for robust OOD PL in real-world alignment tasks.
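The inner step can be illustrated under an assumption the summary does not spell out: that task-specific fine-tuning maximizes the standard KL-regularized objective $\mathbb{E}_{y\sim\pi_\theta}[r_\phi(x,y)] - \beta\,\mathrm{KL}(\pi_\theta \,\|\, \pi_{\rm ref})$. The Python sketch below uses a hypothetical tabular softmax policy over four responses to a single prompt (`inner_loop`, `closed_form_policy`, and all numbers are illustrative, not from the paper). It runs gradient ascent from the reference logits and checks that it reaches the known maximizer $\pi^*(y) \propto \pi_{\rm ref}(y)\exp(r(y)/\beta)$. It is not the paper's implementation, which fine-tunes an LLM policy.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def inner_loop(theta_ref, reward, beta, steps, lr):
    """Hypothetical inner loop: `steps` gradient-ascent updates on
    E_{y~pi}[r(y)] - beta * KL(pi || pi_ref) for a tabular softmax policy,
    initialized at the reference logits (an assumption)."""
    log_ref = [math.log(p) for p in softmax(theta_ref)]
    theta = list(theta_ref)
    for _ in range(steps):
        pi = softmax(theta)
        # Per-response signal g_i = r_i - beta * (log pi_i - log pi_ref_i)
        g = [r - beta * (math.log(p) - lq) for r, p, lq in zip(reward, pi, log_ref)]
        mean_g = sum(p * gi for p, gi in zip(pi, g))
        # Exact gradient of the objective w.r.t. logit j: pi_j * (g_j - E_pi[g])
        theta = [t + lr * p * (gi - mean_g) for t, p, gi in zip(theta, pi, g)]
    return softmax(theta)


def closed_form_policy(theta_ref, reward, beta):
    """Maximizer of the same objective: pi*(y) proportional to pi_ref(y) * exp(r(y) / beta)."""
    return softmax([t + r / beta for t, r in zip(theta_ref, reward)])


theta_ref = [0.0, 0.5, -0.5, 0.2]  # toy reference logits over 4 responses
reward = [1.0, -0.5, 0.3, 0.0]     # toy reward values r_phi(x, y)
approx = inner_loop(theta_ref, reward, beta=1.0, steps=2000, lr=0.5)
exact = closed_form_policy(theta_ref, reward, beta=1.0)
print(max(abs(a - b) for a, b in zip(approx, exact)))
```

Starting from the reference logits mirrors fine-tuning from a pretrained policy; the KL term is what keeps the optimum anchored to $\pi_{\rm ref}$ rather than collapsing onto the highest-reward response.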

Abstract

Preference learning (PL) with large language models (LLMs) aims to align LLM generations with human preferences. Previous work on reinforcement learning from human feedback (RLHF) has shown promising results for in-distribution PL. However, because human feedback is difficult to obtain, training a separate reward model for every encountered distribution is challenging. Out-of-distribution (OOD) PL is therefore practically useful for improving the generalization of LLMs under limited preference feedback. This work addresses OOD PL by optimizing a general reward model through meta-learning. During meta-training, a bilevel optimization algorithm learns a reward model capable of guiding policy learning to align with human preferences across various distributions. When a test distribution is encountered, the meta-test procedure performs regularized policy optimization with the learned reward model for PL. We theoretically establish the convergence rate of the bilevel optimization algorithm under reasonable assumptions. Additionally, we conduct experiments on two text generation tasks across 20 held-out domains and outperform a variety of strong baselines across various evaluation metrics.
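A toy version of the meta-train/meta-test split, offered only as a sketch under several assumptions: a tabular reward $r_\phi(y) = \phi_y$ over four responses shared by all tasks; an inner KL-regularized problem solved exactly in closed form instead of by $D$ gradient steps; and a hypothetical outer loss $-\log\sigma(A)$ on the policy log-ratio $A$ between a preferred and a rejected response. Each "task" is a different reference policy, and the preference labels (lower index preferred) are made up. Meta-test plugs the learned $\phi$ into a held-out reference policy. All function names are illustrative.

```python
import math


def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def margin(theta_ref, phi, beta, y, y_rej):
    """A = log pi*(y) - log pi*(y'), where the inner problem is solved exactly (assumption):
    log pi* = log_softmax(theta_ref + phi / beta)."""
    lp = log_softmax([t + p / beta for t, p in zip(theta_ref, phi)])
    return lp[y] - lp[y_rej]


def meta_train(tasks, n, beta, steps, lr):
    """Outer loop: full-batch gradient descent on the mean of -log sigmoid(A) over all tasks/pairs."""
    phi = [0.0] * n
    n_pairs = sum(len(pairs) for _, pairs in tasks)
    for _ in range(steps):
        grad = [0.0] * n
        for theta_ref, pairs in tasks:
            for y, y_rej in pairs:
                a = margin(theta_ref, phi, beta, y, y_rej)
                # d(-log sigmoid(A))/dA = -sigmoid(-A); dA/dphi = (e_y - e_y') / beta
                coeff = -sigmoid(-a) / beta / n_pairs
                grad[y] += coeff
                grad[y_rej] -= coeff
        phi = [p - lr * g for p, g in zip(phi, grad)]
    return phi


def pl_accuracy(theta_ref, phi, beta, pairs):
    """Fraction of pairs where the regularized policy ranks the preferred response higher."""
    return sum(margin(theta_ref, phi, beta, y, yr) > 0 for y, yr in pairs) / len(pairs)


n = 4
pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]  # toy labels: lower index preferred
train_tasks = [([0.0, 0.0, 0.0, 0.0], pairs),
               ([0.5, -0.5, 0.0, 0.3], pairs),
               ([-0.3, 0.2, 0.4, -0.1], pairs)]
phi = meta_train(train_tasks, n, beta=1.0, steps=1000, lr=1.0)

held_out_ref = [-1.0, 0.0, 0.5, 1.0]  # reference policy that ranks responses backwards
print(pl_accuracy(held_out_ref, [0.0] * n, 1.0, pairs))  # 0.0 without a reward
print(pl_accuracy(held_out_ref, phi, 1.0, pairs))
```

Because the inner solution here is exact, the hypergradient collapses to $-\sigma(-A)(e_y - e_{y'})/\beta$. The paper instead differentiates through a finite number $D$ of inner updates.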

Paper Structure (26 sections, 10 theorems, 62 equations, 6 figures, 6 tables, 2 algorithms)

Key Result

Proposition

For any outer-loop step $k \in \{0, 1, 2, \ldots, K-1\}$, with outer-loop input $z_k$ and inner-loop inputs $\{z_{k,t}\}_{t=0}^{D-1}$, the gradient $\frac{\partial \ell_{\rm PL}(\phi_k, \theta_{k,D})}{\partial \phi_k}$ takes an analytical form, where $A_k := \log \pi_{\theta_{k,D}}(y_k \mid x_k) - \log \pi_{\theta_{k,D}}(y'_k \mid x_k)$ and $R_{k,t} := \exp\left(\frac{1}{\beta} r_{\phi_k}(x_{k,t}, \ldots) \ldots\right)$ …
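The proposition's full expression is cut off in this summary, but the quantity it describes, the derivative of the outer PL loss with respect to $\phi_k$ after $D$ inner updates, can be checked numerically in a toy setting. This sketch rests on stated assumptions: a tabular softmax policy, a tabular reward $r_\phi(y) = \phi_y$, KL-regularized gradient-ascent inner steps, and the outer loss $-\log\sigma(A_k)$ (hypothetical; the paper's $\ell_{\rm PL}$ may differ). It estimates the hypergradient by central finite differences through the unrolled inner loop and compares it with the $D \to \infty$ limit, where the gradient is $-\sigma(-A)\,(e_y - e_{y'})/\beta$. All names are illustrative.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def unrolled_policy(theta_ref, phi, beta, depth, lr):
    """Policy after `depth` gradient-ascent steps on E_pi[r_phi] - beta * KL(pi || pi_ref)."""
    log_ref = [math.log(p) for p in softmax(theta_ref)]
    theta = list(theta_ref)
    for _ in range(depth):
        pi = softmax(theta)
        g = [r - beta * (math.log(p) - lq) for r, p, lq in zip(phi, pi, log_ref)]
        mean_g = sum(p * gi for p, gi in zip(pi, g))
        theta = [t + lr * p * (gi - mean_g) for t, p, gi in zip(theta, pi, g)]
    return softmax(theta)


def pl_loss(phi, theta_ref, beta, depth, lr, y, y_rej):
    """Assumed outer loss -log sigmoid(A), with A the log-ratio under the unrolled policy."""
    pi = unrolled_policy(theta_ref, phi, beta, depth, lr)
    a = math.log(pi[y]) - math.log(pi[y_rej])
    return math.log1p(math.exp(-a))


def finite_diff_grad(phi, eps=1e-5, **kw):
    """Central finite differences of pl_loss w.r.t. each reward parameter."""
    grad = []
    for i in range(len(phi)):
        up = list(phi)
        up[i] += eps
        down = list(phi)
        down[i] -= eps
        grad.append((pl_loss(up, **kw) - pl_loss(down, **kw)) / (2 * eps))
    return grad


def limit_grad(phi, theta_ref, beta, y, y_rej):
    """Closed-form gradient when the inner problem is solved exactly (D -> infinity)."""
    a = (theta_ref[y] - theta_ref[y_rej]) + (phi[y] - phi[y_rej]) / beta
    coeff = -1.0 / (1.0 + math.exp(a)) / beta  # -sigmoid(-A) / beta
    grad = [0.0] * len(phi)
    grad[y] += coeff
    grad[y_rej] -= coeff
    return grad


setup = dict(theta_ref=[0.0, 0.5, -0.5, 0.2], beta=1.0, lr=0.5, y=0, y_rej=1)
phi = [1.0, -0.5, 0.3, 0.0]
print(finite_diff_grad(phi, depth=0, **setup))  # [0.0, 0.0, 0.0, 0.0]
print(finite_diff_grad(phi, depth=2000, **setup))
print(limit_grad(phi, setup["theta_ref"], setup["beta"], setup["y"], setup["y_rej"]))
```

With $D = 0$ the policy never sees the reward, so the gradient is exactly zero; with large $D$ the two estimates should agree to several decimal places. Finite differences are used here only as a check, not as a substitute for the analytical form.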

Figures (6)

  • Figure 1: Comparison between OOD PL and existing ID PL, i.e., PPO and DPO.
  • Figure 2: Overall training process. Meta-training optimizes the RM $\Phi$ using $K$ iterations of SGD. During testing, the RM $\Phi_{K}$ is used to optimize the test policy $\theta_*$.
  • Figure 3: (a)-(d) show the training loss and PL accuracy $\mathcal{A}_{\rm PL}$ per domain, and (e)-(h) show the evaluation accuracy against training steps for each held-out controlled sentiment generation (CSG) distribution.
  • Figure 4: Reward on four held-out Amazon Review distributions.
  • Figure 5: Effects of reward learning w.r.t. the reward controlling factor $\beta$ on the book distribution (a-c) and on the ac,an,ba distribution (d-f).
  • ...and 1 more figure
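Figure 5 studies the reward controlling factor $\beta$. For intuition only, and again assuming the standard KL-regularized objective with closed-form maximizer $\pi^*(y) \propto \pi_{\rm ref}(y)\exp(r(y)/\beta)$, this sketch sweeps $\beta$ over a toy four-response example (all values and names hypothetical): as $\beta$ grows, both the KL divergence from the reference and the expected reward shrink.

```python
import math


def kl_regularized_policy(ref_probs, reward, beta):
    """pi*(y) proportional to pi_ref(y) * exp(r(y) / beta)."""
    weights = [p * math.exp(r / beta) for p, r in zip(ref_probs, reward)]
    z = sum(weights)
    return [w / z for w in weights]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


ref = [0.25, 0.25, 0.25, 0.25]   # toy reference policy
reward = [1.0, 0.5, -0.5, -1.0]  # toy learned rewards
results = []
for beta in [0.1, 0.5, 1.0, 5.0]:
    policy = kl_regularized_policy(ref, reward, beta)
    expected_reward = sum(p * r for p, r in zip(policy, reward))
    kl = kl_divergence(policy, ref)
    results.append((beta, kl, expected_reward))
    print(f"beta={beta}: KL={kl:.3f}, E[r]={expected_reward:.3f}")
```

The trade-off follows from $\frac{d}{d(1/\beta)}\mathrm{KL}(\pi^* \,\|\, \pi_{\rm ref}) = \frac{1}{\beta}\mathrm{Var}_{\pi^*}(r) \ge 0$. It illustrates the mechanism only, not the results reported in Figure 5.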

Theorems & Definitions (22)

  • Proposition
  • Proof
  • Theorem
  • Proof
  • Remark
  • Lemma
  • Proof
  • Lemma
  • Proof
  • Lemma
  • ...and 12 more