
Selective Self-Rehearsal: A Fine-Tuning Approach to Improve Generalization in Large Language Models

Sonam Gupta, Yatin Nandwani, Asaf Yehudai, Mayank Mishra, Gaurav Pandey, Dinesh Raghu, Sachindra Joshi

TL;DR

Selective Self-Rehearsal (SSR) tackles overfitting during fine-tuning of large language models by selectively training on the model's own correct outputs (identified with an LLM judge) rather than relying solely on gold labels. This approach preserves the base model's reasoning and general capabilities while still learning new tasks, achieving SFT-level performance with substantially better generalization across in-domain and out-of-domain datasets. Empirical results show that SSR limits the average performance drop on standard benchmarks to around $2\%$, compared with up to $16.7\%$ for standard supervised fine-tuning. The method is particularly effective for content-grounded QA and conversation, and the authors provide extensive analyses, including human evaluations and generalization tests, that underscore SSR's practical value for robust, domain-general LLM fine-tuning.
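To make the "average performance drop" figures concrete, here is a minimal sketch of one way such a number can be computed. It assumes the drop is measured relative to the base model's score on each benchmark and then averaged over benchmarks; this summary does not say whether the paper uses relative or absolute drops, so treat the definition (and the hypothetical helper names `relative_drop` and `average_drop`) as assumptions.

```python
from statistics import mean


def relative_drop(base_score: float, tuned_score: float) -> float:
    """Percentage drop of a fine-tuned model relative to the base model on one benchmark.

    Assumption: "drop" means (base - tuned) / base, expressed in percent.
    """
    if base_score <= 0:
        raise ValueError("base_score must be positive")
    return 100.0 * (base_score - tuned_score) / base_score


def average_drop(base: dict[str, float], tuned: dict[str, float]) -> float:
    """Mean relative drop over the benchmarks scored for both models (e.g. MMLU, TruthfulQA)."""
    shared = sorted(base.keys() & tuned.keys())
    if not shared:
        raise ValueError("no benchmarks in common")
    return mean(relative_drop(base[name], tuned[name]) for name in shared)
```

Averaging per-benchmark relative drops weights every benchmark equally regardless of its score scale; an absolute-point average would instead favour benchmarks with larger raw scores.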

Abstract

Fine-tuning Large Language Models (LLMs) on specific datasets is a common practice to improve performance on target tasks. However, this performance gain often comes with overfitting: the model becomes too specialized to either the task or the characteristics of the training data and loses generalization. This paper introduces Selective Self-Rehearsal (SSR), a fine-tuning approach that achieves performance comparable to standard supervised fine-tuning (SFT) while improving generalization. SSR leverages the fact that a query can have multiple valid responses. By training on the model's own correct responses, SSR reduces model specialization during fine-tuning. SSR first deploys an appropriate LLM as a judge to identify which of the model's responses on the training set are correct. It then fine-tunes the model on these correct model responses, and on the gold responses for the remaining samples. The effectiveness of SSR is demonstrated through experiments on identifying unanswerable queries across various datasets. The results show that standard SFT can lead to an average performance drop of up to $16.7\%$ on multiple benchmarks, such as MMLU and TruthfulQA. In contrast, SSR results in an average drop of close to $2\%$, indicating better generalization than standard SFT.
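The two-step procedure in the abstract (judge the base model's outputs, then keep either the model's own response or the gold response per sample) can be sketched as below. This is a minimal illustration, not the authors' code: `base_generate` and `judge_accepts` are hypothetical callables standing in for the base model $\mathcal{M}_{\theta_0}$ and the LLM judge, and we assume the judge sees the gold response as a reference when deciding acceptability.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Example:
    prompt: str  # x_i, e.g. a grounding document plus a question
    gold: str    # y_i, the gold response


def build_ssr_targets(
    examples: list[Example],
    base_generate: Callable[[str], str],             # hypothetical stand-in for M_theta0(x_i)
    judge_accepts: Callable[[str, str, str], bool],  # hypothetical judge: (prompt, gold, prediction) -> acceptable?
) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
    """Split the training set into (prompt, target) pairs for subsets R and G."""
    rehearsal: list[tuple[str, str]] = []  # subset R: the model's own accepted responses
    gold_set: list[tuple[str, str]] = []   # subset G: samples that fall back to the gold response
    for ex in examples:
        prediction = base_generate(ex.prompt)
        if judge_accepts(ex.prompt, ex.gold, prediction):
            rehearsal.append((ex.prompt, prediction))
        else:
            gold_set.append((ex.prompt, ex.gold))
    return rehearsal, gold_set
```

Usage: the union of the two returned lists forms the fine-tuning set, so the rest of an ordinary SFT pipeline can stay unchanged; only the choice of target per sample differs.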

Paper Structure (18 sections, 2 equations, 8 figures, 7 tables)

Figures (8)

  • Figure 1: Histogram of the log probability assigned by Mistral-7B-Instruct-v0.2 to the gold responses and its own predictions. The distribution is based on 5,000 examples from the MD2D training data.
  • Figure 2: An overview of our proposed approach. In the example, the document and question are part of the input $x_i$, and the response is the output $y_i$. The LLM judge decides whether the base model output $\mathcal{M}_{\theta_0}(x_i)$ is acceptable. If it is, we use it for loss computation (subset $\mathcal{R}$); otherwise, we use $y_i$ (subset $\mathcal{G}$). See the SSR loss equation (eqn:ssrloss).
  • Figure 3: Comparison of the confusion matrices on the MuSiQue dataset obtained with the base model (a) and with models fine-tuned on MD2D (b and c) and NQ (d and e) using SSR and standard SFT.
  • Figure 4: Mistral-single-turn prompt
  • Figure 5: Mistral-multi-turn prompt
  • ...and 3 more figures
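The Figure 2 caption implies that SSR's loss is an ordinary negative log-likelihood in which the target for sample $i$ is $\mathcal{M}_{\theta_0}(x_i)$ if $i \in \mathcal{R}$ and $y_i$ if $i \in \mathcal{G}$. The sketch below computes such a loss from per-token log-probabilities. It is a hedged reading of the caption, not the paper's equation: we assume an unweighted sum over both subsets, averaged over all target tokens, with prompt tokens already excluded; the paper's exact normalization or any subset weighting may differ.

```python
def ssr_loss(rehearsal_logprobs: list[list[float]], gold_logprobs: list[list[float]]) -> float:
    """Token-averaged negative log-likelihood over subsets R and G (assumed form).

    Each inner list holds log p_theta(token | x_i, previous target tokens) for the
    target tokens of one example: the base model's accepted response for R,
    the gold response y_i for G. Prompt tokens are assumed to be excluded already.
    """
    all_sequences = rehearsal_logprobs + gold_logprobs
    n_tokens = sum(len(seq) for seq in all_sequences)
    if n_tokens == 0:
        raise ValueError("no target tokens")
    total_logprob = sum(sum(seq) for seq in all_sequences)
    # Negate: higher log-probability on the chosen targets means lower loss.
    return -total_logprob / n_tokens
```

Because the only difference from standard SFT is which target sequence each sample contributes, this also relates to Figure 1: targets in $\mathcal{R}$ are the model's own predictions, which the base model already assigns relatively high log-probability, so they push the weights less than gold responses would.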