
Vanishing Gradients in Reinforcement Finetuning of Language Models

Noam Razin, Hattie Zhou, Omid Saremi, Vimal Thilak, Arwen Bradley, Preetum Nakkiran, Joshua Susskind, Etai Littwin

TL;DR

The paper proves that the expected gradient for an input vanishes when its reward standard deviation under the model is small, even if the expected reward is far from optimal. It therefore argues that being mindful of inputs whose expected gradient vanishes is crucial for successful execution of RFT.

Abstract

Pretrained language models are commonly aligned with human preferences and downstream tasks via reinforcement finetuning (RFT), which refers to maximizing a (possibly learned) reward function using policy gradient algorithms. This work identifies a fundamental optimization obstacle in RFT: we prove that the expected gradient for an input vanishes when its reward standard deviation under the model is small, even if the expected reward is far from optimal. Through experiments on an RFT benchmark and controlled environments, as well as a theoretical analysis, we then demonstrate that vanishing gradients due to small reward standard deviation are prevalent and detrimental, leading to extremely slow reward maximization. Lastly, we explore ways to overcome vanishing gradients in RFT. We find the common practice of an initial supervised finetuning (SFT) phase to be the most promising candidate, which sheds light on its importance in an RFT pipeline. Moreover, we show that a relatively small number of SFT optimization steps on as few as 1% of the input samples can suffice, indicating that the initial SFT phase need not be expensive in terms of compute and data labeling efforts. Overall, our results emphasize that being mindful of inputs whose expected gradient vanishes, as measured by the reward standard deviation, is crucial for successful execution of RFT.
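The abstract's central claim can be checked numerically in a deliberately simplified setting that is our own assumption, not the paper's model: a single output drawn from a softmax policy whose logits are the parameters, with exact (not sampled) gradients. The helper names are hypothetical. The sketch compares an input whose rewarded output has probability 0.001 with one where both outputs are equally likely.

```python
import math

def reward_stats(probs, rewards):
    """Exact mean and standard deviation of the reward under the policy."""
    mean = sum(p * r for p, r in zip(probs, rewards))
    var = sum(p * (r - mean) ** 2 for p, r in zip(probs, rewards))
    return mean, math.sqrt(var)

def expected_reward_grad(probs, rewards):
    """Exact gradient of E[r] with respect to the softmax logits: p_j * (r_j - E[r])."""
    mean, _ = reward_stats(probs, rewards)
    return [p * (r - mean) for p, r in zip(probs, rewards)]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

rewards = [0.0, 1.0]  # hypothetical: output 1 is the only rewarded output
for probs in ([0.999, 0.001], [0.5, 0.5]):
    mean, std = reward_stats(probs, rewards)
    grad = expected_reward_grad(probs, rewards)
    print(f"mean={mean:.3f} std={std:.4f} |grad|={l2_norm(grad):.4f}")
# mean=0.001 std=0.0316 |grad|=0.0014
# mean=0.500 std=0.5000 |grad|=0.3536
```

In the first case the expected reward is far from the optimum of 1, yet the reward standard deviation, and with it the gradient, is tiny. In this single-step setting the gradient norm can never exceed the reward standard deviation.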

Paper Structure

This paper contains 28 sections, 4 theorems, 85 equations, 18 figures, and 7 tables.

Key Result

Theorem 1

For parameters $\theta \in \mathbb{R}^P$ and input $\mathbf{x} \in \mathcal{X}^{L_{in}}$, denote the reward standard deviation of $\mathbf{x}$ under the model by $\mathrm{STD}_{\mathbf{y} \sim p_{\theta}(\cdot \mid \mathbf{x})}[r(\mathbf{x}, \mathbf{y})]$. … where $\gamma(\mathbf{x}; \theta) := \max_{l \in \{1, \ldots, L_{out}\}, \ldots} \ldots$
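The mechanism behind Theorem 1 can be seen in a much simpler case, which is our own simplification and not the theorem itself (Theorem 1 covers autoregressive models through the quantity $\gamma(\mathbf{x}; \theta)$). Assume a single output $y \in \{1, \ldots, K\}$ drawn from $p = \mathrm{softmax}(z)$, with the logits $z \in \mathbb{R}^K$ as parameters and rewards $r_1, \ldots, r_K$:

```latex
V(z) = \mathbb{E}_{y \sim p}[r(y)] = \sum_{i=1}^{K} p_i r_i,
\qquad
\frac{\partial p_i}{\partial z_j} = p_i \left( \mathbb{1}[i = j] - p_j \right),

\frac{\partial V}{\partial z_j}
  = \sum_{i=1}^{K} r_i \, p_i \left( \mathbb{1}[i = j] - p_j \right)
  = p_j \left( r_j - V(z) \right),

\left\| \nabla_z V(z) \right\|_2
  \le \left\| \nabla_z V(z) \right\|_1
  = \mathbb{E}_{y \sim p}\big[ \lvert r(y) - V(z) \rvert \big]
  \le \sqrt{ \mathbb{E}_{y \sim p}\big[ (r(y) - V(z))^2 \big] }
  = \mathrm{STD}_{y \sim p}[r(y)].
```

The last inequality is Jensen's inequality. If all outputs with non-negligible probability receive nearly the same reward, the gradient is small regardless of how far $V(z)$ is from $\max_i r_i$.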

Figures (18)

  • Figure 1: Inputs with small reward standard deviation under the pretrained model, i.e. with vanishing expected gradient, are prevalent in the GRUE benchmark. For randomly chosen subsets of $5000$ train samples from the NarrativeQA, ToTTo, and IMDB datasets, the plots show the reward means and standard deviations (estimated based on ten generations per input) under the pretrained, RFT, and SFT models. The samples are ordered according to their pretrain reward standard deviation, with each marker representing the reward mean and standard deviation (depicted by color) of an individual sample. Notice that a significant number of samples from NarrativeQA and ToTTo have small pretrain reward standard deviation, while their reward mean is low. Accordingly, RFT struggles to improve the reward of these inputs, especially compared to SFT. In contrast, IMDB does not suffer from this issue, and RFT and SFT affect it more similarly. Identical experiments with the remaining GRUE datasets appear in the appendix on further experiments.
  • Figure 2: RFT performance (relative to SFT) is worse when inputs with small reward standard deviation are prevalent. Per dataset, the difference between the mean reward achieved by RFT and SFT is plotted against the $10$th percentile of the reward standard deviation under the pretrained model. Means are taken over five runs, and error bars (indiscernible) mark standard deviations. We exclude inputs with near-optimal reward mean under the pretrained model (higher than $0.9$) when computing the percentiles, since small reward standard deviation is only problematic if the reward mean is not already high. Observe that the lower the pretrain reward standard deviation percentile is, i.e. the more train samples have small pretrain reward standard deviation, the worse the reward that RFT achieves relative to SFT.
  • Figure 3: RFT struggles to maximize the reward over inputs with small reward standard deviation under the pretrained model, i.e. inputs with vanishing expected gradient, even with perfect exploration. In contrast, SFT easily leads to maximal reward. For the controlled environments described in the paper's section on controlled experiments, in which RFT has access to expected gradients, displayed are the train reward (top), reward standard deviation (middle), and gradient norm (bottom) for RFT and SFT throughout optimization, separately for train samples with small and with large pretrain reward standard deviation. An identical experiment with stochastic gradient descent instead of Adam appears in the appendix on further experiments.
  • Figure 4: On datasets in which RFT suffers from vanishing gradients, a few initial SFT optimization steps on a small number of labeled inputs substantially boost the efficacy of RFT. For the NarrativeQA dataset, reported metrics are based on the mean train reward achieved when performing RFT after an initial SFT phase with various percentages of optimization steps and labeled inputs (over three random seeds). A “full” SFT phase refers to $100\%$ of the steps and labeled inputs used by Ramamurthy et al. (2023). Observe that the number of optimization steps and labeled inputs can be greatly reduced without causing a significant degradation in reward (left). Furthermore, RFT becomes roughly $5$ to $18$ times more potent after the initial SFT phase (right). The appendix on further experiments provides analogous plots reporting metrics based on the mean test reward, as well as identical experiments on the ToTTo and CommonGen datasets.
  • Figure 5: Representative example of an input with small (top) and an input with large (bottom) reward standard deviation under the pretrained model from the NarrativeQA dataset.
  • ...and 13 more figures
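The per-input statistics behind Figures 1 and 2 can be sketched as follows, assuming the rewards of the sampled generations (ten per input in the figures) are already available. The function names are hypothetical. Using the population standard deviation (`statistics.pstdev`) and a linearly interpolated percentile are our assumptions; the captions do not specify either choice. The $0.9$ mean threshold and the $10$th percentile come from the Figure 2 caption.

```python
import math
import statistics

def per_input_reward_stats(rewards_per_input):
    """Map each input's sampled rewards to (mean, std).

    rewards_per_input: list of lists, one inner list of generation rewards per input.
    Population std is an assumption; the paper's estimator may differ.
    """
    return [(statistics.fmean(rs), statistics.pstdev(rs)) for rs in rewards_per_input]

def percentile(values, q):
    """Linearly interpolated q-th percentile (0 <= q <= 100) of a non-empty list."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

def low_std_percentile(stats, q=10, max_mean=0.9):
    """q-th percentile of reward std, skipping inputs whose mean exceeds max_mean."""
    stds = [std for mean, std in stats if mean <= max_mean]
    return percentile(stds, q)

# Hypothetical rewards for three inputs, ten generations each.
rewards = [
    [0.0] * 10,      # low mean, zero std: vanishing expected gradient
    [0.0, 1.0] * 5,  # mean 0.5, std 0.5
    [1.0] * 10,      # near-optimal mean: excluded from the percentile
]
stats = per_input_reward_stats(rewards)
order = sorted(range(len(stats)), key=lambda i: stats[i][1])  # order inputs by std, as in Figure 1
print(stats, order, low_std_percentile(stats))
```

Filtering by mean before taking the percentile mirrors the caption's reasoning. An input with zero standard deviation and a mean of 1 is already solved and should not count as a vanishing-gradient problem.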

Theorems & Definitions (5)

  • Theorem 1
  • Proposition 1
  • Proposition 2
  • Theorem 2
  • Proof sketch (full proof in the appendix)
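The RFT-versus-SFT contrast reported in Figure 3 can be reproduced qualitatively in a toy setting. The setting is entirely our own assumption: a single-token softmax policy with the logits as parameters, a 0/1 reward for one correct output that initially has probability 0.001, and plain gradient ascent with learning rate 1. It is not the paper's controlled environment, and it is not the setting of Theorem 2. The function names are hypothetical. RFT follows the exact expected-reward gradient, matching the perfect-exploration regime. SFT follows the gradient of the log-likelihood of the correct output.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def rft_grad(logits, target):
    """Exact gradient of E[r] for r(y) = 1[y == target]: p_j * (1[j == target] - p_target)."""
    p = softmax(logits)
    return [p[j] * ((1.0 if j == target else 0.0) - p[target]) for j in range(len(p))]

def sft_grad(logits, target):
    """Gradient of log p_target, i.e. negative cross-entropy: 1[j == target] - p_j."""
    p = softmax(logits)
    return [(1.0 if j == target else 0.0) - p[j] for j in range(len(p))]

def steps_to_reach(grad_fn, logits, target, lr=1.0, threshold=0.9, max_steps=100_000):
    """Count gradient-ascent steps until p_target >= threshold (capped at max_steps)."""
    z = list(logits)
    for step in range(max_steps):
        if softmax(z)[target] >= threshold:
            return step
        g = grad_fn(z, target)
        z = [zi + lr * gi for zi, gi in zip(z, g)]
    return max_steps

init = [math.log(0.001), math.log(0.999)]  # correct output 0 starts with probability 0.001
print("RFT steps:", steps_to_reach(rft_grad, init, target=0))
print("SFT steps:", steps_to_reach(sft_grad, init, target=0))
```

Every RFT gradient entry carries the factor $p_{target}$, which stays tiny while the reward standard deviation is small. The SFT gradient has no such factor, so in this toy setting RFT needs many times more steps than SFT to reach the same reward.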