Proxy-RLHF: Decoupling Generation and Alignment in Large Language Model with Proxy

Yu Zhu, Chuxiong Sun, Wenfei Yang, Wenqiang Wei, Bo Tang, Tianzhu Zhang, Zhiyu Li, Shifeng Zhang, Feiyu Xiong, Jie Hu, Mingchuan Yang

Abstract

Reinforcement Learning from Human Feedback (RLHF) is the prevailing approach for ensuring that Large Language Models (LLMs) align with human values. However, existing RLHF methods incur a high computational cost, in large part because RLHF assigns both the generation task and the alignment task to the LLM simultaneously. In this paper, we introduce Proxy-RLHF, which decouples the generation and alignment processes of LLMs and achieves alignment with human values at a much lower computational cost. We start with a novel Markov Decision Process (MDP) designed for the alignment process and use Reinforcement Learning (RL) to train a streamlined proxy model that oversees the token generation of the LLM without altering the LLM itself. Experiments show that our method achieves a comparable level of alignment with only 1% of the training parameters of other methods.
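The abstract and Figure 1 describe a proxy that accepts or rejects each token proposed by an unchanged LLM. The Python sketch below illustrates only that control flow, under stated assumptions: `toy_llm` and `toy_proxy` are hypothetical stand-ins (a real proxy would be a small RL-trained policy, not a hard-coded rule), `temperature` is assumed to scale the LLM's logits before sampling, and the behavior after a rejection (masking the rejected token and resampling from the remaining candidates) is an assumption of this sketch, not a detail confirmed by the paper.

```python
import math
import random
from typing import Callable, Dict, List, Optional

Logits = Dict[str, float]  # next-token logits keyed by token string


def softmax_with_temperature(logits: Logits, temperature: float) -> Dict[str, float]:
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = {tok: v / temperature for tok, v in logits.items()}
    top = max(scaled.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(v - top) for tok, v in scaled.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


def proxy_guided_generate(
    llm_logits: Callable[[List[str]], Logits],
    proxy_accept: Callable[[List[str], str], bool],
    max_len: int = 10,
    temperature: float = 1.0,
    eos: str = "<eos>",
    seed: int = 0,
) -> List[str]:
    """A frozen LLM proposes tokens; a separate proxy accepts or rejects each one."""
    rng = random.Random(seed)
    prefix: List[str] = []
    for _ in range(max_len):
        # The LLM is never updated here; only its proposals are filtered.
        candidates = softmax_with_temperature(llm_logits(prefix), temperature)
        accepted: Optional[str] = None
        while candidates:
            tokens = list(candidates)
            weights = [candidates[t] for t in tokens]
            if sum(weights) <= 0.0:  # remaining probability mass underflowed to zero
                break
            # rng.choices treats weights as relative, so no explicit renormalization is needed.
            token = rng.choices(tokens, weights=weights, k=1)[0]
            if proxy_accept(prefix, token):
                accepted = token
                break
            # Assumption: a rejected token is removed and the LLM resamples from the rest.
            del candidates[token]
        if accepted is None:  # the proxy rejected every candidate: stop early
            break
        prefix.append(accepted)
        if accepted == eos:
            break
    return prefix


# Hypothetical toy components for illustration only.
def toy_llm(prefix: List[str]) -> Logits:
    return {"helpful": 1.0, "harmful": 2.0, "<eos>": 0.5}


def toy_proxy(prefix: List[str], token: str) -> bool:
    return token != "harmful"


output = proxy_guided_generate(toy_llm, toy_proxy, max_len=10, temperature=0.5)
print(output)
```

Even though the toy LLM assigns the highest logit to `"harmful"`, that token never reaches the output, because the filtering happens outside the LLM. This mirrors the decoupling idea in the abstract: only the proxy would need training, while the LLM's weights stay fixed.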

Paper Structure

This paper contains 17 sections, 4 equations, 5 figures, and 4 tables.

Figures (5)

  • Figure 1: Demonstration of how the proxy model works. The proxy model is responsible for supervising the generation of the LLM, deciding whether to accept the latest token generated by the LLM. By accepting tokens that align with human values and rejecting those that do not, it ensures that the final generation results are aligned with human values.
  • Figure 2: (a) The reward distribution on the test set for SFT and Ours, where scores are obtained from the reward model. (b) The win rates of Ours, DPO, RLHF, and BON against the SFT model, determined by pairwise comparison with GPT-4 as the judge. We use greedy sampling for all methods above and set $n=32$ for BON.
  • Figure 3: (a) The average score given by the reward model on the test set for models corresponding to different values of $p_t$, after one pass over the training set. (b) The average score corresponding to different temperatures, after one pass over the training set.
  • Figure 4: (a)-(d) The reward distribution given by the reward model on the test set for the SFT model and for models corresponding to different temperatures, after one pass over the training set: (a) temperature $= 0.25$, (b) temperature $= 0.5$, (c) temperature $= 0.75$, (d) temperature $= 1.0$. (e)-(h) The average test-set score for models with different temperatures during training on the first 2k training samples: (e) temperature $= 0.25$, (f) temperature $= 0.5$, (g) temperature $= 0.75$, (h) temperature $= 1.0$.
  • Figure 5: (a)-(d) The reward distribution given by the reward model on the test set for the SFT model and for models corresponding to different values of $p_t$, after one pass over the training set: (a) $p_t=0.1$, (b) $p_t=0.01$, (c) $p_t=0.001$, (d) $p_t=0.0001$. (e)-(h) The average test-set score for models with different values of $p_t$ during training on the first 2k training samples: (e) $p_t=0.1$, (f) $p_t=0.01$, (g) $p_t=0.001$, (h) $p_t=0.0001$.
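Figure 2(b) reports win rates against the SFT model from GPT-4 pairwise comparisons. As a minimal illustration of that arithmetic, the Python sketch below converts per-prompt verdicts into a win rate. The label set (`"win"`, `"tie"`, `"lose"`) and the choice to count a tie as half a win (`tie_weight`) are assumptions, since the captions do not say how ties were scored, and the example verdicts are invented rather than taken from the paper.

```python
from collections import Counter
from typing import Iterable


def win_rate(judgments: Iterable[str], tie_weight: float = 0.5) -> float:
    """Fraction of pairwise comparisons won against the baseline (e.g. SFT).

    Each judgment is "win", "tie", or "lose" from the evaluated model's side.
    Counting a tie as `tie_weight` of a win is an assumption; 0.0 ignores ties.
    """
    counts = Counter(judgments)
    unknown = set(counts) - {"win", "tie", "lose"}
    if unknown:
        raise ValueError(f"unexpected labels: {sorted(unknown)}")
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no judgments")
    return (counts["win"] + tie_weight * counts["tie"]) / total


# Hypothetical verdicts, not results from the paper.
example = ["win", "win", "tie", "lose"]
print(win_rate(example))                  # 0.625
print(win_rate(example, tie_weight=0.0))  # 0.5
```

Making `tie_weight` explicit keeps the scoring convention visible, since the same verdicts yield different win rates under different tie conventions.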