Imitating Language via Scalable Inverse Reinforcement Learning

Markus Wulfmeier, Michael Bloesch, Nino Vieillard, Arun Ahuja, Jorg Bornschein, Sandy Huang, Artem Sokolov, Matt Barnes, Guillaume Desjardins, Alex Bewley, Sarah Maria Elisabeth Bechtle, Jost Tobias Springenberg, Nikola Momchev, Olivier Bachem, Matthieu Geist, Martin Riedmiller

TL;DR

The paper tackles the challenge that MLE-based language model fine-tuning may underutilize the sequential structure of autoregressive generation. It reframes imitation learning as inverse reinforcement learning and derives a TD-regularized extension of inverse soft Q-learning that connects to MLE via distribution matching. The authors introduce a principled offline IQLearn objective along with adversarial (GAIL) and non-adversarial variants, and demonstrate via experiments on T5 and PaLM2 models across GSM8k, XSUM, TLDR, and WMT22 that IRL methods can achieve equal or better task performance with noticeably increased generation diversity, often with lower compute when trained offline. Reward analysis shows that IRL-extracted rewards correlate with task performance, suggesting potential for improved reward design and smoother integration with RLHF stages. Overall, the work provides a scalable, data-efficient alternative to purely supervised fine-tuning and offers actionable insights for leveraging IRL in the LLM training pipeline.

Abstract

The majority of language model training builds on imitation learning. It covers pretraining, supervised fine-tuning, and affects the starting conditions for reinforcement learning from human feedback (RLHF). The simplicity and scalability of maximum likelihood estimation (MLE) for next token prediction led to its role as the predominant paradigm. However, the broader field of imitation learning can more effectively utilize the sequential structure underlying autoregressive generation. We focus on investigating the inverse reinforcement learning (IRL) perspective on imitation, extracting rewards and directly optimizing sequences instead of individual token likelihoods, and evaluate its benefits for fine-tuning large language models. We provide a new angle, reformulating inverse soft-Q-learning as a temporal difference regularized extension of MLE. This creates a principled connection between MLE and IRL and allows trading off added complexity for increased performance and diversity of generations in the supervised fine-tuning (SFT) setting. We find clear advantages for IRL-based imitation, in particular for retaining diversity while maximizing task performance, rendering IRL a strong alternative on fixed SFT datasets even without online data generation. Our analysis of IRL-extracted reward functions further indicates benefits for more robust reward functions via tighter integration of supervised and preference-based LLM post-training.
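
To make the phrase "temporal difference regularized extension of MLE" more concrete, the following is a minimal NumPy sketch of how such a loss could look for a single expert sequence. It is an illustration under stated assumptions, not the paper's exact objective: the model logits are read as soft Q-values, the soft value is their log-sum-exp, the discount is fixed to 1, the value after the final token is taken to be 0, and the names `td_regularized_mle_loss` and `td_weight` are hypothetical.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def td_regularized_mle_loss(q_values, tokens, td_weight=0.1):
    """Illustrative loss: MLE term plus a squared temporal-difference penalty.

    q_values: array [T, V], logits interpreted as soft Q-values (assumption).
    tokens:   int array [T], the expert tokens of one sequence.
    td_weight: hypothetical regularization strength; 0 recovers plain MLE.
    """
    num_steps = q_values.shape[0]
    v = logsumexp(q_values, axis=-1)                  # soft value V(s_t)
    q_taken = q_values[np.arange(num_steps), tokens]  # Q(s_t, a_t)
    log_prob = q_taken - v                            # log pi(a_t | s_t)
    mle = -np.mean(log_prob)                          # standard next-token loss

    # Assumption: discount of 1 and zero value after the last token.
    v_next = np.append(v[1:], 0.0)                    # V(s_{t+1})
    implied_reward = q_taken - v_next                 # r_t = Q(s_t, a_t) - V(s_{t+1})
    td_penalty = np.mean(implied_reward ** 2)         # couples neighbouring steps

    return mle + td_weight * td_penalty, mle, td_penalty


# Usage on random logits for a 5-token sequence over a 10-token vocabulary.
rng = np.random.default_rng(0)
loss, mle, penalty = td_regularized_mle_loss(
    rng.normal(size=(5, 10)), np.array([1, 4, 2, 9, 0])
)
```

In this sketch the first term only looks at the current token, while the penalty ties each step's Q-value to the value of the following state, which is one way to see how information about future tokens can enter training on a fixed dataset.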

Paper Structure

This paper contains 33 sections, 17 equations, 12 figures, and 5 tables.

Figures (12)

  • Figure 1: Data usage and optimization flow in MLE, offline and online IRL. Independent of the method, current models use the history of past tokens to predict the next. However, MLE purely optimizes the current output for exactly matching the corresponding datapoint, while IRL-based methods take into account the impact on future tokens. Online optimization additionally conditions on past model generations rather than the original dataset. Grey and blue objects respectively represent training data and model generations. The impact of future datapoints is often indirect and mediated via learned functions (e.g. the discriminator in GAIL (Ho and Ermon, 2016) and the Q-function in IQLearn (Garg et al., 2022)).
  • Figure 2: GSM8k results for fine-tuning with MLE, IQLearn, and GAIL across different regularization strengths. In particular, MLE shows a strong performance reduction with higher entropy cost. Larger models demonstrate higher performance but also stronger self-similarity across generations, rendering an effective trade-off between task performance and diversity highly relevant. Error bars indicate the standard error of the mean after repeating the experiment with 3 different seeds.
  • Figure 3: XSUM results for models trained with MLE, IQLearn, and GAIL across different regularization strengths. ROUGE-1 and ROUGE-2 are used as performance metrics on the x-axes, with Self-BLEU as the diversity measure on the y-axis. Entropy-regularizing large MLE- and GAIL-trained models with an entropy cost of 0.1 leads to catastrophic results outside the limits of the plot. A figure in the appendix shows the corresponding plots for ROUGE-LSUM.
  • Figure 4: PaLM2 results for various sampling temperatures with MLE and IQLearn. Left: GSM8k, Mid: TLDR, Right: WMT22, including beam search. By propagating sequence information during training, IQLearn reduces the inference-time dependency on beam search for improving performance.
  • Figure 5: Left: performance of offline and online inverse RL, with the online ratio describing the ratio of offline data used. Right: diversity of model generations. While performance shows only limited gains, diversity clearly improves.
  • ...and 7 more figures
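
The Self-BLEU diversity measure named in the Figure 3 caption can be illustrated with a short, dependency-free sketch. This is a simplified version written for illustration, not the evaluation code behind the figures: it assumes whitespace tokenization, n-grams up to order 2, no smoothing, and a brevity penalty against the closest reference length; the function names `bleu` and `self_bleu` are hypothetical. Each generation is scored against all other generations for the same prompt, so higher values mean more self-similar (less diverse) samples.

```python
from collections import Counter
import math


def ngrams(tokens, n):
    """Count all n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def bleu(hyp, refs, max_n=2):
    """Simplified BLEU: clipped n-gram precision, geometric mean, brevity penalty."""
    log_precision = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = ngrams(hyp, n)
        total = sum(hyp_counts.values())
        if total == 0:
            return 0.0
        # Clip each n-gram count by its maximum count in any single reference.
        max_ref = Counter()
        for ref in refs:
            for gram, count in ngrams(ref, n).items():
                max_ref[gram] = max(max_ref[gram], count)
        clipped = sum(min(count, max_ref[gram]) for gram, count in hyp_counts.items())
        if clipped == 0:
            return 0.0  # no smoothing in this sketch
        log_precision += math.log(clipped / total) / max_n
    # Brevity penalty relative to the reference closest in length.
    ref_len = min((len(r) for r in refs), key=lambda length: (abs(length - len(hyp)), length))
    bp = 1.0 if len(hyp) >= ref_len else math.exp(1.0 - ref_len / len(hyp))
    return bp * math.exp(log_precision)


def self_bleu(samples, max_n=2):
    """Average BLEU of each sample against all other samples."""
    tokenized = [s.split() for s in samples]
    scores = [
        bleu(hyp, tokenized[:i] + tokenized[i + 1:], max_n)
        for i, hyp in enumerate(tokenized)
    ]
    return sum(scores) / len(scores)


# Usage: identical samples are maximally self-similar, disjoint ones are not.
repeated = self_bleu(["the cat sat down", "the cat sat down", "the cat sat down"])
varied = self_bleu(["the cat sat down", "a dog ran off", "birds fly south today"])
```

For reported numbers, an established BLEU implementation with a fixed tokenizer and smoothing scheme would be preferable, since scores from this sketch are not comparable to published Self-BLEU values.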