Understanding Forgetting in LLM Supervised Fine-Tuning and Preference Learning -- A Convex Optimization Perspective
Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen
TL;DR
This work analyzes forgetting in the standard two-stage post-training of LLMs (supervised fine-tuning (SFT) followed by preference alignment via DPO/RLHF) and proves that sequential optimization can incur a non-diminishing trade-off gap between the two objectives. It proposes a principled joint post-training framework built on alternating updates, with two algorithms, ALRIGHT and MAXRIGHT, that come with convergence guarantees and are efficient in practice. Experiments on models such as Llama3-8b and Pythia-1b show substantial improvements over sequential training (up to ~23% across benchmarks) with low overhead, and performance that is competitive with or better than simply mixing the two objectives. The framework thus offers a scalable, data-efficient way to balance alignment and task-specific fine-tuning in LLM post-training, backed by theoretical guarantees and practical applicability.
Abstract
The post-training of LLMs, which typically consists of a supervised fine-tuning (SFT) stage and a preference learning stage (RLHF or DPO), is crucial to effective and safe LLM applications. The widely adopted approach in post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, this is suboptimal in terms of the SFT and RLHF/DPO trade-off: the LLM gradually forgets what it learned in the first stage while being trained in the second. This sequential paradigm persists largely because its simplicity and modularity make it easier to implement and manage at scale, despite its limitations. We theoretically prove the sub-optimality of sequential post-training and propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, with up to 23% overall performance improvement across multiple LLM evaluation benchmarks and minimal computational overhead. Our code is available at https://github.com/heshandevaka/XRIGHT.
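To make the sequential-versus-joint trade-off concrete, here is a minimal one-dimensional sketch under stated assumptions. The quadratics `f_sft` and `f_pref` are hypothetical stand-ins for the SFT and preference objectives, the weight `lam` defines an assumed target mixture `lam * f_pref + (1 - lam) * f_sft`, and the alternating loop samples which objective to update with probability `lam` and uses a diminishing step size. It follows the general alternating-update idea described above but is not the authors' ALRIGHT or MAXRIGHT implementation.

```python
import random

# Hypothetical toy stand-ins for the two post-training objectives (assumption):
# f_sft(theta) = 0.5 * (theta - A)^2, f_pref(theta) = 0.5 * (theta - B)^2.
A, B = 1.0, -1.0


def f_sft(theta):
    return 0.5 * (theta - A) ** 2


def f_pref(theta):
    return 0.5 * (theta - B) ** 2


def joint_gap(theta, lam):
    """Sub-optimality of the assumed mixture lam * f_pref + (1 - lam) * f_sft."""
    def mix(t):
        return lam * f_pref(t) + (1 - lam) * f_sft(t)

    star = lam * B + (1 - lam) * A  # closed-form minimizer of the mixture
    return mix(theta) - mix(star)


def sequential(steps=200, lr=0.1):
    """Stage 1 optimizes only f_sft, then stage 2 optimizes only f_pref."""
    theta = 0.0
    for _ in range(steps):
        theta -= lr * (theta - A)  # gradient step on f_sft
    for _ in range(steps):
        theta -= lr * (theta - B)  # gradient step on f_pref
    return theta


def alternating(lam, steps=5000, seed=0):
    """Each step samples one objective (f_pref with prob. lam) and updates on it."""
    rng = random.Random(seed)
    theta = 0.0
    for t in range(steps):
        lr = 1.0 / (t + 1)  # diminishing step size; theta becomes a running average
        target = B if rng.random() < lam else A
        theta -= lr * (theta - target)  # gradient step on the sampled objective
    return theta


if __name__ == "__main__":
    lam = 0.5
    print("sequential gap: ", joint_gap(sequential(), lam))
    print("alternating gap:", joint_gap(alternating(lam), lam))
```

In this toy setting, sequential training ends at the preference optimum and leaves a mixture gap of about 0.5 that does not shrink with more steps, which mirrors the non-diminishing gap the abstract describes, while the alternating loop approaches the mixture optimum.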
