
Learning Goal-Conditioned Representations for Language Reward Models

Vaskar Nath, Dylan Slack, Jeff Da, Yuntao Ma, Hugh Zhang, Spencer Whitehead, Sean Hendryx

TL;DR

This work proposes training reward models in a contrastive, goal-conditioned fashion by increasing the representation similarity of future states along sampled preferred trajectories and decreasing the similarity along randomly sampled dispreferred trajectories, and finds that these representations can perform fine-grained control by conditioning on desired future goal-states.

Abstract

Techniques that learn improved representations via offline data or self-supervised objectives have shown impressive results in traditional reinforcement learning (RL). Nevertheless, it is unclear how improved representation learning can benefit reinforcement learning from human feedback (RLHF) on language models (LMs). In this work, we propose training reward models (RMs) in a contrastive, $\textit{goal-conditioned}$ fashion by increasing the representation similarity of future states along sampled preferred trajectories and decreasing the similarity along randomly sampled dispreferred trajectories. This objective significantly improves RM performance by up to 0.09 AUROC across challenging benchmarks, such as MATH and GSM8K. These findings extend to general alignment as well -- on the Helpful-Harmless dataset, we observe a $2.3\%$ increase in accuracy. Beyond improving reward model performance, we show this way of training RM representations enables improved $\textit{steerability}$ because it allows us to evaluate the likelihood of an action achieving a particular goal-state (e.g., whether a solution is correct or helpful). Leveraging this insight, we find that we can filter up to $55\%$ of generated tokens during majority voting by discarding trajectories likely to end up in an "incorrect" state, which leads to significant cost savings. We additionally find that these representations can perform fine-grained control by conditioning on desired future goal-states. For example, we show that steering a Llama 3 model towards helpful generations with our approach improves helpfulness by $9.6\%$ over a supervised-fine-tuning trained baseline. Similarly, steering the model towards complex generations improves complexity by $21.6\%$ over the baseline. Overall, we find that training RMs in this contrastive, goal-conditioned fashion significantly improves performance and enables model steerability.
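
The contrastive objective described in the abstract can be illustrated with a minimal sketch. It assumes per-token hidden states are already available as plain lists of floats, uses cosine similarity as the similarity measure, and applies a binary logistic loss to the positive and negative pairs. The function names, the data layout, and the exact loss form are illustrative assumptions and are not taken from the paper's implementation.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def contrastive_goal_loss(preferred, dispreferred, rng):
    """Hypothetical contrastive loss for one preference pair.

    preferred / dispreferred: lists of per-token hidden-state vectors.
    The preferred list needs at least two tokens so that a future
    goal token exists after the source token.
    """
    # Source state: any token of the preferred response except the last.
    src = rng.randrange(len(preferred) - 1)
    # Positive goal: a strictly later token of the preferred response.
    pos_goal = rng.randrange(src + 1, len(preferred))
    # Negative goal: any token of the dispreferred response.
    neg_goal = rng.randrange(len(dispreferred))

    sim_pos = cosine(preferred[src], preferred[pos_goal])
    sim_neg = cosine(preferred[src], dispreferred[neg_goal])

    # Assumed loss form: pull the positive pair together and push the
    # negative pair apart. Note log(1 - sigmoid(x)) == log_sigmoid(-x).
    return -log_sigmoid(sim_pos) - log_sigmoid(-sim_neg)


rng = random.Random(0)
# Toy representations: preferred tokens agree with each other and
# point away from the dispreferred tokens, so the loss is low.
aligned_loss = contrastive_goal_loss(
    [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]],
    [[-1.0, 0.0], [-1.0, 0.0]],
    rng,
)
# Reversed geometry: the future preferred token is dissimilar and the
# dispreferred token is similar, so the loss is high.
misaligned_loss = contrastive_goal_loss(
    [[1.0, 0.0], [-1.0, 0.0]],
    [[1.0, 0.0]],
    rng,
)
```

In this toy setup the loss is smaller when future preferred states resemble the source state and dispreferred states do not, which is the direction the training objective pushes the representations.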

Paper Structure (37 sections, 8 equations, 10 figures, 15 tables)

Figures (10)

  • Figure 1: Overview of contrastive goal-conditioned learning for text. Pictured is a prompt with a preferred and dispreferred response. Both source state tokens (ten) for the positive and negative trajectories are sampled from the preferred response. For illustrative purposes, the positive and negative source states are sampled as the same token, but in practice they can be different. The positive goal state is sampled as some future token (subtract) from the preferred response, and the negative goal state is sampled from any token (add) from the dispreferred response. The corresponding representations are retrieved from the last hidden state of the reward model. The training objective is then to maximize and minimize the similarity of the positive and negative representation pairs, respectively.
  • Figure 2: AUROC scores comparing the baseline Codellama 7b Reward vs. our proposed method Q-Function 7b Reward on the rewards attributed to the base-model greedy generations across several math benchmarks.
  • Figure 3: AUROC scores on the rewards attributed to partial base-model generations across 50 samples on GSM8K and MATH. The error bars depict the 95% confidence intervals (with sample size $n=50$) at each percentile of generation considered. The Q-Function reward model's performance increases incrementally as more information becomes available, whereas the traditional reward model's performance varies much more when attributing intermediate rewards.
  • Figure 4: Two examples from the GSM8K dataset that were filtered via Q-value pruning. The token-level Q-values are portrayed as a heat map where the colors red and blue represent scores close to $-1$ and $1$, respectively. Both examples illustrate that the Q-values pinpoint the exact logical error in reasoning. The full version of these examples can be found in Appendix \ref{app:q_val_heatmap}.
  • Figure 5: We plot the average proportion of non-filtered responses that are correct and the average number of completions (out of a total of 50) that are filtered, as we vary the sample size (1, 10, 100, 1K, 10K, and 30K) of dispreferred, preferred, and corrupted examples when constructing the goal state. All of the examples are retrieved from the preference ranking training dataset. We plot the performance metrics for the MATH and GSM8K test sets. Each test set has 50 completions per problem, generated by the base model, and we filter any completion for which the Q-value of any token in the completion sequence is below 0. At each sample size, we draw 3 independent random samples from the full set of preferred and dispreferred completions to construct 3 independent goal states, and show 95% confidence intervals.
  • ...and 5 more figures
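
The filtering rule stated in the Figure 5 caption (discard any completion with a token-level Q-value below 0) combined with majority voting can be sketched as follows. This is a minimal illustration that assumes each completion is a dictionary with a final `answer` and a list of per-token `q_values` in $[-1, 1]$. The dictionary layout, the function names, and the toy data are hypothetical. The sketch filters finished completions, whereas the token savings reported in the abstract would come from abandoning a trajectory as soon as its Q-value drops below the threshold.

```python
from collections import Counter


def filter_by_q_values(completions, threshold=0.0):
    """Keep completions whose every token-level Q-value is >= threshold."""
    return [
        c for c in completions
        if all(q >= threshold for q in c["q_values"])
    ]


def majority_vote(completions):
    """Return the most frequent final answer, or None if nothing is left."""
    if not completions:
        return None
    counts = Counter(c["answer"] for c in completions)
    return counts.most_common(1)[0][0]


# Toy data (made up for illustration): five sampled completions for one problem.
samples = [
    {"answer": "42", "q_values": [0.6, 0.4, 0.7]},
    {"answer": "41", "q_values": [0.5, -0.3, 0.1]},
    {"answer": "41", "q_values": [0.2, -0.8, -0.5]},
    {"answer": "42", "q_values": [0.3, 0.1, 0.2]},
    {"answer": "41", "q_values": [0.4, 0.2, -0.1]},
]

kept = filter_by_q_values(samples)
vote_without_filter = majority_vote(samples)
vote_with_filter = majority_vote(kept)
```

In this toy example the unfiltered vote is dominated by completions that contain a low-Q token, and removing them changes the voted answer. Whether filtering helps on real data depends on how well the Q-values track correctness.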