How to Leverage Demonstration Data in Alignment for Large Language Model? A Self-Imitation Learning Perspective
Teng Xiao, Mingxiao Li, Yige Yuan, Huaisheng Zhu, Chao Cui, Vasant G Honavar
TL;DR
This work tackles the problem of aligning large language models using offline human demonstrations, without costly preference labeling or adversarial training. It introduces generalized self-imitation learning (GSIL), which reframes imitation learning as a density-ratio estimation task and derives a self-normalized, closed-form policy update that removes the need for a reinforcement learning loop. GSIL supports a family of density-ratio losses and trains on both real demonstration data and self-generated data. It achieves significant improvements over SFT and SPIN on math, coding, and reasoning benchmarks, surpasses some DPO results, and also improves safety alignment. The approach offers a practical, scalable path for demonstration-based alignment and a unified perspective that bridges imitation learning and density-ratio estimation for offline LLM fine-tuning.
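The idea of learning from demonstrations versus self-generated data through a classification loss can be sketched as follows. This is a minimal, illustrative sketch, not the authors' implementation. It treats demonstration responses as the positive class and self-generated responses as the negative class. It also assumes that the classifier logit is β times the difference between the policy and reference log-probabilities of a response. That parameterization, the function names, and the default `beta` are assumptions made here for illustration.

```python
import math
from typing import Sequence


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def gsil_logistic_loss(
    logp_policy_demo: Sequence[float],  # log pi_theta(y|x) for demonstration responses
    logp_ref_demo: Sequence[float],     # log pi_ref(y|x) for the same responses
    logp_policy_self: Sequence[float],  # log pi_theta(y|x) for self-generated responses
    logp_ref_self: Sequence[float],     # log pi_ref(y|x) for the same responses
    beta: float = 0.1,                  # hypothetical scale on the log-ratio logit
) -> float:
    """Logistic density-ratio (classification) loss, as an illustrative sketch.

    Assumed logit: z = beta * (log pi_theta - log pi_ref).
    Demonstrations are labeled 1 and self-generated samples are labeled 0.
    """
    pos = [beta * (p - r) for p, r in zip(logp_policy_demo, logp_ref_demo)]
    neg = [beta * (p - r) for p, r in zip(logp_policy_self, logp_ref_self)]
    # Push demonstration logits up and self-generated logits down.
    loss_pos = -sum(log_sigmoid(z) for z in pos) / len(pos)
    loss_neg = -sum(log_sigmoid(-z) for z in neg) / len(neg)
    return loss_pos + loss_neg
```

Consider a demonstration response to which the policy assigns more probability than the reference does. Its positive-class term falls below log 2, and the loss decreases. With logistic loss and balanced classes, the optimal logit approximates the log density ratio between the two sample distributions. This is the standard reason classification losses can serve as density-ratio estimators.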
Abstract
This paper introduces a novel generalized self-imitation learning ($\textbf{GSIL}$) framework, which effectively and efficiently aligns large language models with offline demonstration data. We develop $\textbf{GSIL}$ by deriving a surrogate objective of imitation learning with density ratio estimates. This facilitates the use of self-generated data and lets us optimize the imitation learning objective with simple classification losses. $\textbf{GSIL}$ eliminates the need for complex adversarial training in standard imitation learning, achieving lightweight and efficient fine-tuning for large language models. In addition, $\textbf{GSIL}$ encompasses a family of offline losses parameterized by a general class of convex functions for density ratio estimation, and it enables a unified view of alignment with demonstration data. Extensive experiments show that $\textbf{GSIL}$ consistently and significantly outperforms baselines on many challenging benchmarks, such as coding (HumanEval), mathematical reasoning (GSM8K), and instruction following (MT-Bench).
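A family of classification losses parameterized by a convex function can be illustrated with a single margin function φ applied to classifier logits. Demonstration logits are scored with φ(z), and self-generated logits with φ(−z). The three choices below (logistic, exponential, squared) are standard convex margin losses picked only for illustration. The abstract does not say which convex functions GSIL actually instantiates, so treat `LOSSES`, `ratio_loss`, and every φ here as hypothetical. With balanced classes and density ratio r = p_demo / p_self, each loss's optimal logit is a monotone function of r: log r for logistic, ½ log r for exponential, and (r − 1)/(r + 1) for squared.

```python
import math
from typing import Callable, Dict, Sequence


def logistic(t: float) -> float:
    """log(1 + exp(-t)), computed stably."""
    return math.log1p(math.exp(-t)) if t >= 0 else -t + math.log1p(math.exp(t))


def exponential(t: float) -> float:
    """exp(-t); may overflow for very negative t (kept simple for the sketch)."""
    return math.exp(-t)


def squared(t: float) -> float:
    """(1 - t)^2 margin loss."""
    return (1.0 - t) ** 2


# Hypothetical menu of convex margin losses; not taken from the paper.
LOSSES: Dict[str, Callable[[float], float]] = {
    "logistic": logistic,
    "exponential": exponential,
    "squared": squared,
}


def ratio_loss(
    demo_logits: Sequence[float],
    self_logits: Sequence[float],
    phi: Callable[[float], float],
) -> float:
    """Balanced binary classification loss with margin function phi.

    Demonstration logits are pushed up via phi(z); self-generated logits
    are pushed down via phi(-z).
    """
    pos = sum(phi(z) for z in demo_logits) / len(demo_logits)
    neg = sum(phi(-z) for z in self_logits) / len(self_logits)
    return pos + neg


if __name__ == "__main__":
    for name, phi in LOSSES.items():
        print(name, ratio_loss([0.5, 1.2], [-0.3, 0.4], phi))
```

Swapping φ changes only the loss shape, not the data pipeline. This is one way a single training loop could cover a whole family of objectives.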
