Improving Generalization of Alignment with Human Preferences through Group Invariant Learning
Rui Zheng, Wei Shen, Yuan Hua, Wenbin Lai, Shihan Dou, Yuhao Zhou, Zhiheng Xi, Xiao Wang, Haoran Huang, Tao Gui, Qi Zhang, Xuanjing Huang
TL;DR
The paper tackles the generalization gap in RLHF by introducing group invariant learning (GIL), which automatically partitions data into groups and enforces invariant policy performance across them. It jointly learns group labels and a robust policy. Stage 1 infers group assignments with a critic-based predictor that maximizes the variance of returns across groups. Stage 2 optimizes the policy with a variance-based regularizer, guided by an adaptive KL penalty. The approach combines a distributionally robust objective with dynamic regularization that strengthens learning on hard groups while curbing reward hacking on easy groups. Empirical results on dialogue and summarization tasks show improved training stability and better generalization to out-of-distribution data than PPO-based baselines and DPO, with consistent results under both human and GPT-4 evaluation. Overall, the method advances RLHF by promoting more uniform, reliable alignment with human preferences across diverse data regimes.
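The two-stage procedure described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. It assumes that per-sample returns are given and that there are exactly two groups. It replaces the learned critic-based group predictor of Stage 1 with a hypothetical exhaustive threshold split that maximizes the variance of the group-mean returns. It also assumes that the Stage 2 regularized objective takes the form "mean of group returns minus beta times their variance". The names `infer_groups`, `invariant_objective`, and `beta` are illustrative.

```python
from statistics import mean, pvariance


def group_means(returns, labels):
    """Average return per group label."""
    groups = {}
    for r, g in zip(returns, labels):
        groups.setdefault(g, []).append(r)
    return {g: mean(rs) for g, rs in groups.items()}


def infer_groups(returns):
    """Stage 1 stand-in (hypothetical, not the paper's critic-based predictor).

    Splits samples into a hard group (0) and an easy group (1) by trying every
    threshold on the sorted returns and keeping the split whose two group means
    have the largest variance. Requires at least two samples.
    """
    if len(returns) < 2:
        raise ValueError("need at least two samples to form two groups")
    order = sorted(range(len(returns)), key=lambda i: returns[i])
    best_labels, best_var = None, -1.0
    for k in range(1, len(returns)):
        labels = [0] * len(returns)
        for i in order[k:]:          # the higher-return samples form group 1
            labels[i] = 1
        var = pvariance(list(group_means(returns, labels).values()))
        if var > best_var:           # keep the split with maximal variance
            best_var, best_labels = var, labels
    return best_labels


def invariant_objective(returns, labels, beta=1.0):
    """Stage 2 stand-in (assumed form): mean group return minus beta * variance."""
    means = list(group_means(returns, labels).values())
    return mean(means) - beta * pvariance(means)
```

Take the returns `[0.1, 0.2, 0.9, 1.0]`. Isolating the two low-return samples gives the largest variance between group means, so those samples form the hard group. Under this assumed objective, a policy whose groups average `[0.5, 0.6]` has the same overall mean as one averaging `[0.15, 0.95]` but scores higher. This is the sense in which the variance penalty favors uniform performance across groups.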
Abstract
The success of AI assistants based on large language models (LLMs) hinges crucially on Reinforcement Learning from Human Feedback (RLHF), which enables the generation of responses more aligned with human preferences. As universal AI assistants, they are increasingly expected to perform consistently across various domains. However, previous work has shown that Reinforcement Learning (RL) often exploits shortcuts to attain high rewards while overlooking challenging samples. This focus on quick reward gains undermines both training stability and the model's ability to generalize to new, unseen data. In this work, we propose a novel approach that learns a consistent policy via RL across various data groups or domains. Because group annotations are difficult to acquire, our method automatically classifies data into different groups, deliberately maximizing the performance variance between them. We then optimize the policy to perform well on the challenging groups. Finally, leveraging the established groups, our approach adaptively adjusts the exploration space, allocating more learning capacity to challenging data and preventing the model from over-optimizing on simpler data. Experimental results indicate that our approach significantly enhances training stability and model generalization.
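The adaptive adjustment of the exploration space can be illustrated with a per-group KL penalty on the reward, in the style of standard KL-regularized RLHF. The scaling rule below is a hypothetical assumption chosen for illustration, because the abstract does not specify one. Each group's coefficient is `base_beta` multiplied by the exponential of how far that group's mean return sits above the overall mean. Easier groups therefore get a tighter constraint, and harder groups get more room to explore. `group_kl_coefs`, `shaped_reward`, and `base_beta` are illustrative names.

```python
import math


def group_kl_coefs(group_mean_returns, base_beta=0.1):
    """Per-group KL coefficients (hypothetical scaling rule, not from the paper).

    Groups whose mean return is above the overall mean (easier) get a
    coefficient above base_beta; groups below it (harder) get one below.
    """
    overall = sum(group_mean_returns.values()) / len(group_mean_returns)
    return {g: base_beta * math.exp(m - overall)
            for g, m in group_mean_returns.items()}


def shaped_reward(reward, logp_policy, logp_ref, beta):
    """KL-penalized reward: r - beta * (log pi(y|x) - log pi_ref(y|x))."""
    return reward - beta * (logp_policy - logp_ref)
```

With group means `{0: 0.15, 1: 0.95}`, the hard group 0 receives a coefficient below `base_beta` and the easy group 1 receives one above it. For the same log-probability gap, a sample from the hard group is therefore penalized less. This lets the policy move further from the reference model on hard prompts while limiting over-optimization on easy ones.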
