Understanding Forgetting in LLM Supervised Fine-Tuning and Preference Learning -- A Convex Optimization Perspective

Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen

TL;DR

This work analyzes forgetting in the standard two-stage post-training of LLMs (SFT and preference alignment via DPO/RLHF) and proves that sequential optimization can incur a non-diminishing trade-off gap between objectives. It proposes a principled joint post-training framework built on alternating updates, with two algorithms, ALRIGHT and MAXRIGHT, providing convergence guarantees and practical efficiency. Empirical results on models like Llama3-8b and Pythia-1b show substantial improvements over sequential training (up to ~23% across benchmarks) while keeping overhead low, and performance competitive with or surpassing simple objective mixing. The framework offers a scalable, data-efficient path to balance alignment and task-specific fine-tuning in LLM post-training, with strong theoretical guarantees and real-world applicability.

Abstract

The post-training of LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning stage (RLHF or DPO), is crucial to effective and safe LLM applications. The widely adopted approach in post-training popular open-source LLMs is to perform SFT and RLHF/DPO sequentially. However, this is suboptimal in terms of the SFT and RLHF/DPO trade-off: the LLM gradually forgets the first stage's training while undergoing the second stage's training. This sequential paradigm persists largely due to its simplicity and modularity, which make it easier to implement and manage at scale despite its limitations. We theoretically prove the sub-optimality of sequential post-training and propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, with up to 23% overall performance improvement across multiple LLM evaluation benchmarks and minimal computational overhead. Our code is available at https://github.com/heshandevaka/XRIGHT.

Paper Structure

This paper contains 30 sections, 9 theorems, 58 equations, 14 figures, 3 tables, 3 algorithms.

Key Result

Theorem 3.3

Consider Algorithm algo:seq-rlhf-sft with $T_{\text{\tiny DPO}}=T_{\text{\tiny SFT}}=T$ under Assumption ass:features. Then there exist data $\mathcal{D}_{\text{\tiny DPO}}$ and $\mathcal{D}_{\text{\tiny SFT}}$ such that, given any $\lambda\in(0,1)$, Algorithm algo:seq-rlhf-sft with any sufficientl where $\mathbb{E}[~\cdot~]$ is taken over the randomness of Algorithm algo:seq-rlhf-sft.

Figures (14)

  • Figure 1: Efficient Trade-off in RLHF and SFT Optimization. Sequential optimization (e.g., RLHF first, then SFT) often biases the model towards the optimum of the later-stage objective, as illustrated by the optimization trajectories in the objective space (left) and the performance comparison (top right, lower is better). In contrast, simultaneous optimization of a Mix of RLHF and SFT objectives achieves a more balanced performance but requires significantly more resources (bottom right, lower is better). We propose ALRIGHT and MAXRIGHT strategies for joint RLHF and SFT optimization, offering an improved trade-off with minimal extra cost.
  • Figure 2: Toy example. (a) Sequential DPO and SFT: Model oscillates between the optima of DPO and SFT objectives in parameter space, resulting in a final trade-off that is far away from the ideal point in objective space where both DPO and SFT objective values are optimal. (b) ALRIGHT / (c) MAXRIGHT: Model directly navigates towards a point in parameter space that is reasonably optimal for both DPO and SFT objectives (average optimum), achieving a final trade-off of DPO and SFT objectives much closer to the ideal point.
  • Figure 3: Comparison in first DPO then SFT setting using pythia-1b model. Left: Training trajectories in the objective space. Right: Performance comparison across multiple evaluation metrics, including optimality gap for DPO and SFT objectives, ideal distance, runtime, and GPU utilization. The bar charts highlight the trade-offs and resource efficiency of each method for different choices of $(T_{\text{\tiny DPO}}, T_{\text{\tiny SFT}})$ or $\lambda$.
  • Figure 4: Comparison in first DPO then SFT setting using Llama3-8b model. Evaluation using the MMLU, HellaSwag, SORRY-Bench, and XSTest benchmarks, along with runtime and GPU utilization efficiencies across different post-training methods.
  • Figure 5: Comparison in first SFT then DPO setting using opt-1.3b model. Evaluation using win rate against the chosen responses on 50 test-set samples from the HuggingFaceH4/ultrafeedback_binarized dataset, as judged by the UltraRM-13b reward model.
  • ...and 9 more figures
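
The contrast described in the Figure 2 caption can be illustrated with a minimal sketch. It is not the paper's implementation: the two quadratic objectives, their optima `A_DPO` and `B_SFT`, the step sizes, and the rule of picking one objective at random per step with probability `lam` are all assumptions chosen for illustration. The actual ALRIGHT and MAXRIGHT update rules are defined in the paper and may differ.

```python
import math
import random

# Hypothetical 1-D stand-ins for the two objectives (not the paper's losses).
# Each is a convex quadratic with optimal value 0 at its own optimum.
A_DPO, B_SFT = -1.0, 1.0


def f_dpo(theta):
    return 0.5 * (theta - A_DPO) ** 2


def f_sft(theta):
    return 0.5 * (theta - B_SFT) ** 2


def grad_dpo(theta):
    return theta - A_DPO


def grad_sft(theta):
    return theta - B_SFT


def ideal_distance(theta):
    # Distance in objective space to the ideal point (0, 0),
    # where both toy objectives would be optimal at once.
    return math.hypot(f_dpo(theta), f_sft(theta))


def sequential(theta, steps, lr=0.1):
    # Stage 1: optimize only the first objective.
    for _ in range(steps):
        theta -= lr * grad_dpo(theta)
    # Stage 2: optimize only the second objective; stage 1 is "forgotten".
    for _ in range(steps):
        theta -= lr * grad_sft(theta)
    return theta


def alternating(theta, steps, lam=0.5, seed=0):
    # Assumed rule: at each step, follow the first objective's gradient with
    # probability lam, otherwise the second's, with a 1/(t+1) step size.
    rng = random.Random(seed)
    for t in range(2 * steps):  # same total step budget as sequential()
        grad = grad_dpo(theta) if rng.random() < lam else grad_sft(theta)
        theta -= grad / (t + 1)
    return theta


theta_seq = sequential(0.0, steps=1000)
theta_alt = alternating(0.0, steps=1000)
print("sequential :", theta_seq, ideal_distance(theta_seq))
print("alternating:", theta_alt, ideal_distance(theta_alt))
```

In this toy setting the sequential run ends at the optimum of the second objective, so the first objective's gap stays large no matter how long each stage runs, while the alternating run settles near a point between the two optima and lands closer to the ideal point in objective space.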

Theorems & Definitions (19)

  • Remark 3.1
  • Theorem 3.3: Lower bound for sequential method performance
  • Theorem 4.1: Upper bound for alternating method performance
  • Remark 4.2
  • Remark 4.3
  • Proposition A.1: Bounded second moment of stochastic gradient
  • proof
  • Proposition A.2: Convexity of objectives
  • proof
  • proof
  • ...and 9 more