MetaRM: Shifted Distributions Alignment via Meta-Learning

Shihan Dou, Yan Liu, Enyu Zhou, Tianlong Li, Haoxiang Jia, Limao Xiong, Xin Zhao, Junjie Ye, Rui Zheng, Tao Gui, Qi Zhang, Xuanjing Huang

TL;DR

MetaRM addresses the distribution-shift problem in reward models during iterative RLHF by meta-learning a shifted-distribution-aware RM. It performs a one-step gradient ascent on a difference loss $\mathcal{J}_{\theta}$ computed on a meta-set drawn from the shifted environment, producing $\theta'$. It then computes the vanilla preference loss on the original data at $\theta'$ and updates $\theta$ by gradient descent. The resulting gradient contains a dot-product term between the meta and vanilla gradients, which favors data that are informative for both objectives. Across dialogue and summarization tasks on HH-RLHF and OOD datasets, MetaRM improves RM discrimination and language model performance over SFT, PPO, and DPO, particularly in early RLHF rounds. It also remains robust to distribution shift without requiring additional labeled data. These results suggest that MetaRM can enable more reliable, label-efficient adaptation of reward models during iterative RLHF, with greater sensitivity to subtle differences in shifted distributions.
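The dot-product term can be seen with a first-order Taylor expansion. This is a generic meta-learning derivation under assumed notation, not a transcription of the paper's equations. Here $\mathcal{L}$ is the vanilla preference loss on the original data $X$, $X_s$ is the meta-set, and $\eta$, $\alpha$ are the ascent and descent step sizes.

```latex
\begin{aligned}
\theta_t' &= \theta_t + \eta\,\nabla_{\theta}\mathcal{J}_{\theta_t}(X_s), \\
\theta_{t+1} &= \theta_t - \alpha\,\nabla_{\theta_t}\mathcal{L}_{\theta_t'}(X), \\
\mathcal{L}_{\theta_t'}(X) &\approx \mathcal{L}_{\theta_t}(X)
  + \eta\,\nabla_{\theta}\mathcal{L}_{\theta_t}(X)^{\top}\,\nabla_{\theta}\mathcal{J}_{\theta_t}(X_s).
\end{aligned}
```

Minimizing the right-hand side lowers the vanilla loss and pushes the inner product $\nabla\mathcal{L}^{\top}\nabla\mathcal{J}$ negative. In other words, it favors parameters at which the descent direction $-\nabla\mathcal{L}$ on the original preferences agrees with the ascent direction $\nabla\mathcal{J}$ on the shifted distribution. Taking the gradient with respect to $\theta_t$ through $\theta_t'$ therefore adds the gradient of this dot-product term, scaled by $\eta$, to the ordinary vanilla gradient.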

Abstract

The success of Reinforcement Learning from Human Feedback (RLHF) in language model alignment depends critically on the capability of the reward model (RM). However, as training progresses, the output distribution of the policy model shifts, reducing the RM's ability to distinguish between responses. This issue is compounded when an RM trained on a specific data distribution struggles to generalize to examples outside that distribution. Both issues can be unified as a single challenge: the shifted distribution of the environment. To address this challenge, we introduce MetaRM, a method that leverages meta-learning to align the RM with the shifted environment distribution. MetaRM trains the RM by minimizing the loss on the original data, with emphasis on data that improve its ability to differentiate examples from the shifted target distribution. Extensive experiments demonstrate that MetaRM significantly improves the RM's ability to distinguish responses in iterative RLHF optimization and also enables it to identify subtle differences in out-of-distribution samples.
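The meta-learning update summarized in this abstract and in the four steps of Figure 2 can be sketched in dependency-free Python. This is a minimal sketch under stated assumptions, not the authors' implementation. The reward model is assumed linear over precomputed feature vectors. The difference objective (mean squared reward gap between two unlabeled responses from the shifted distribution) is a hypothetical stand-in for the paper's difference loss. All function names, `eta`, and `alpha` are illustrative.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical setup (not from the paper): every (query, response) is already
# encoded as a feature vector, and the reward model is linear in theta.
Vec = Sequence[float]
Pair = Tuple[Vec, Vec]


def reward(theta: Vec, phi: Vec) -> float:
    """Linear stand-in for the reward model r_theta(x, y)."""
    return sum(t * p for t, p in zip(theta, phi))


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def vanilla_loss(theta: Vec, pref_pairs: List[Pair]) -> float:
    """Pairwise preference loss -log sigmoid(r(chosen) - r(rejected))."""
    return -sum(
        log_sigmoid(reward(theta, c) - reward(theta, r)) for c, r in pref_pairs
    ) / len(pref_pairs)


def vanilla_grad(theta: Vec, pref_pairs: List[Pair]) -> List[float]:
    grad = [0.0] * len(theta)
    for c, r in pref_pairs:
        # d/dd of -log sigmoid(d) is -(1 - sigmoid(d)); dd/dtheta = c - r.
        coef = -(1.0 - sigmoid(reward(theta, c) - reward(theta, r))) / len(pref_pairs)
        for k in range(len(theta)):
            grad[k] += coef * (c[k] - r[k])
    return grad


def difference_objective(theta: Vec, meta_pairs: List[Pair]) -> float:
    """ASSUMED difference objective: mean squared reward gap between two
    unlabeled responses sampled from the shifted distribution."""
    return sum(
        (reward(theta, a) - reward(theta, b)) ** 2 for a, b in meta_pairs
    ) / len(meta_pairs)


def difference_grad(theta: Vec, meta_pairs: List[Pair]) -> List[float]:
    grad = [0.0] * len(theta)
    for a, b in meta_pairs:
        coef = 2.0 * (reward(theta, a) - reward(theta, b)) / len(meta_pairs)
        for k in range(len(theta)):
            grad[k] += coef * (a[k] - b[k])
    return grad


def metarm_step(theta: Vec, pref_pairs: List[Pair], meta_pairs: List[Pair],
                eta: float = 0.1, alpha: float = 0.1) -> List[float]:
    """One first-order MetaRM-style update following the four steps of Figure 2."""
    # Steps 1-2: one gradient-ascent step on the difference objective -> theta'.
    g_meta = difference_grad(theta, meta_pairs)
    theta_prime = [t + eta * g for t, g in zip(theta, g_meta)]
    # Steps 3-4: gradient of the vanilla loss at theta', applied to theta as descent.
    g_vanilla = vanilla_grad(theta_prime, pref_pairs)
    return [t - alpha * g for t, g in zip(theta, g_vanilla)]
```

For example, `metarm_step([0.0, 0.0], [([1.0, 0.0], [0.0, 1.0])], [([1.0, 1.0], [0.0, 0.0])])` returns `[0.05, -0.05]` (up to floating-point rounding). At zero parameters the meta step is a no-op, so the update reduces to an ordinary preference-loss step. The sketch evaluates the vanilla gradient at $\theta'$ and applies it to $\theta$, following the Figure 2 caption literally, which drops second-order terms. Differentiating through $\theta'$ with an autodiff framework would recover the full meta-gradient.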

Paper Structure (19 sections, 8 equations, 6 figures, 2 tables, 1 algorithm)

Figures (6)

  • Figure 1: Variance of the reward difference distribution. We randomly select 1K queries; for each query, we sample two responses from the model's output distribution and compute the difference between their rewards, yielding the reward difference distribution. As RL training progresses, the model's output distribution shifts and the RM increasingly fails to distinguish between responses, so the variance decreases. This indicates that the RM struggles to capture subtle differences between responses when the environment distribution shifts.
  • Figure 2: The pipeline of our proposed approach, MetaRM. MetaRM consists of four simple steps: 1. Compute the difference loss on responses sampled from the shifted distribution. 2. Calculate the gradient of this loss w.r.t. the RM parameters $\theta_t$ and update the parameters along the ascent direction. 3. Compute the vanilla loss on the original preference pairs using the updated parameters $\theta_{t}^{'}$. 4. Calculate the gradient of the vanilla loss w.r.t. $\theta_{t}^{'}$ and update the original parameters $\theta_t$ along the descent direction.
  • Figure 3: Results on the out-of-distribution task compared to SFT and vanilla PPO. Our method outperforms both baselines by adapting the reward model to the new distribution.
  • Figure 4: Accuracy curves on the validation set during the reward model training phase. MetaRM achieves accuracy similar to that of the original RM training method. This indicates that our method preserves the RM's ability to model human preferences through the gradient-descent step while adapting it to the new distribution through the meta-process.
  • Figure 5: Reward difference distributions for the original RM training method and MetaRM, normalized to the range from zero to one. This indicates that MetaRM can enhance the RM's ability to distinguish samples from a shifted environment distribution through meta-learning.
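The statistics behind Figures 1 and 5 can be sketched as follows. The sketch assumes a generic `sample_response(query, rng)` sampler and a scalar `reward(query, response)` function. The toy sampler, the two toy RMs, and all function names are hypothetical. Min-max scaling is only one plausible reading of "normalized to a range of zero to one", not a confirmed detail of the paper.

```python
import random
import statistics
from typing import Any, Callable, List, Sequence

# Hypothetical toy setup: 1K queries, a "response" is a float drawn from N(0, 1).
QUERIES = [f"q{i}" for i in range(1000)]


def toy_sample(query: str, rng: random.Random) -> float:
    return rng.gauss(0.0, 1.0)


def sharp_rm(query: str, response: float) -> float:
    """Toy RM that preserves gaps between responses."""
    return response


def flat_rm(query: str, response: float) -> float:
    """Toy RM that compresses gaps, mimicking an RM that cannot tell responses apart."""
    return 0.1 * response


def reward_difference_distribution(
    queries: Sequence[str],
    sample_response: Callable[[str, random.Random], Any],
    reward: Callable[[str, Any], float],
    rng: random.Random,
) -> List[float]:
    """Sample two responses per query and record r(q, y1) - r(q, y2)."""
    diffs = []
    for q in queries:
        y1 = sample_response(q, rng)
        y2 = sample_response(q, rng)
        diffs.append(reward(q, y1) - reward(q, y2))
    return diffs


def min_max_normalize(values: Sequence[float]) -> List[float]:
    """ASSUMED [0, 1] normalization for Figure 5; constant inputs map to zeros."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


if __name__ == "__main__":
    for name, rm in [("sharp", sharp_rm), ("flat", flat_rm)]:
        # Same seed for both RMs so they score identical response samples.
        diffs = reward_difference_distribution(QUERIES, toy_sample, rm, random.Random(0))
        print(name, "variance:", statistics.pvariance(diffs))
```

Because `flat_rm` scales every reward gap by 0.1, its variance comes out about 100 times smaller than that of `sharp_rm` on the same samples. This is the shrinking-variance signature that Figure 1 associates with an RM losing its ability to distinguish responses.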