Selective Self-Rehearsal: A Fine-Tuning Approach to Improve Generalization in Large Language Models
Sonam Gupta, Yatin Nandwani, Asaf Yehudai, Mayank Mishra, Gaurav Pandey, Dinesh Raghu, Sachindra Joshi
TL;DR
Selective Self-Rehearsal (SSR) tackles overfitting during fine-tuning of large language models. It trains on the model's own responses wherever an LLM judge deems them correct, and falls back to the gold labels only for the remaining samples. This preserves the base model's reasoning and general capabilities while the model still learns the new task. SSR matches SFT-level performance on the target task and generalizes substantially better across in-domain and out-of-domain datasets. Empirically, SSR keeps the average performance drop on standard benchmarks such as MMLU and TruthfulQA to around $2\%$, compared with up to $16.7\%$ for standard supervised fine-tuning (SFT). The method is particularly effective for content-grounded QA and conversation. The authors support it with extensive analyses, including human evaluations and generalization tests, which underscore SSR's practical value for robust, domain-general LLM fine-tuning.
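The selection rule described above can be sketched as a small data-construction step that runs before ordinary supervised fine-tuning. This is a minimal illustration, not the authors' implementation. The names `Example`, `build_ssr_targets`, `generate` and `judge` are hypothetical. The sketch also assumes the judge sees the gold response as a reference when deciding whether the model's response is correct. The fine-tuning step itself (standard SFT on the returned pairs) is not shown.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Example:
    query: str  # user query (plus any grounding content)
    gold: str   # reference response from the training set


def build_ssr_targets(
    examples: List[Example],
    generate: Callable[[str], str],          # hypothetical: the model being fine-tuned
    judge: Callable[[str, str, str], bool],  # hypothetical LLM judge: (query, gold, candidate) -> correct?
) -> List[Tuple[str, str]]:
    """Return (query, target) pairs for fine-tuning.

    The target is the model's own response when the judge accepts it,
    and the gold response otherwise.
    """
    pairs = []
    for ex in examples:
        candidate = generate(ex.query)
        target = candidate if judge(ex.query, ex.gold, candidate) else ex.gold
        pairs.append((ex.query, target))
    return pairs


# Toy usage: stub functions stand in for the model and the LLM judge.
def stub_generate(query: str) -> str:
    return f"model answer to {query}"


def stub_judge(query: str, gold: str, candidate: str) -> bool:
    return query == "q1"  # pretend only the answer to q1 is judged correct


data = [Example("q1", "gold-1"), Example("q2", "gold-2")]
targets = build_ssr_targets(data, stub_generate, stub_judge)
print(targets)  # [('q1', 'model answer to q1'), ('q2', 'gold-2')]
```

Keeping the model's own correct responses means those samples ask the model to reproduce text close to what it already generates. This is the intuition behind SSR's reduced specialization. Only the samples the model gets wrong push it toward the gold distribution.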
Abstract
Fine-tuning Large Language Models (LLMs) on specific datasets is a common practice to improve performance on target tasks. However, this performance gain often comes with overfitting. The model becomes too specialized in either the task or the characteristics of the training data and loses its ability to generalize. This paper introduces Selective Self-Rehearsal (SSR), a fine-tuning approach that achieves performance comparable to standard supervised fine-tuning (SFT) while improving generalization. SSR leverages the fact that a query can have multiple valid responses. By training on the model's own correct responses, SSR reduces model specialization during fine-tuning. SSR first obtains the model's responses to the training queries and deploys an appropriate LLM as a judge to identify which of them are correct. It then fine-tunes the model on these correct model responses, and on the gold responses for the remaining samples. The effectiveness of SSR is demonstrated through experiments on the task of identifying unanswerable queries across various datasets. The results show that standard SFT can lead to an average performance drop of up to $16.7\%$ on multiple benchmarks, such as MMLU and TruthfulQA. In contrast, SSR results in an average drop of close to $2\%$, indicating better generalization than standard SFT.
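The abstract does not define how the "average performance drop" is computed. The sketch below assumes the drop on each benchmark is measured relative to the base model's score and then averaged across benchmarks. An absolute drop in score points is offered as an alternative, and the paper's exact definition may differ. The function name `average_drop` and the scores in the usage example are hypothetical toy values, not results from the paper.

```python
from typing import Mapping


def average_drop(base: Mapping[str, float], tuned: Mapping[str, float],
                 relative: bool = True) -> float:
    """Average performance drop of a fine-tuned model vs. its base model.

    ASSUMPTION: with relative=True, each benchmark's drop is
    100 * (base - tuned) / base, i.e. a percentage of the base score;
    with relative=False it is the plain difference in score points.
    """
    if set(base) != set(tuned):
        raise ValueError("both models must be scored on the same benchmarks")
    if not base:
        raise ValueError("need at least one benchmark")
    drops = []
    for name, base_score in base.items():
        diff = base_score - tuned[name]
        drops.append(100.0 * diff / base_score if relative else diff)
    return sum(drops) / len(drops)


# Toy scores for illustration only; these are NOT numbers from the paper.
base_scores = {"MMLU": 60.0, "TruthfulQA": 50.0}
tuned_scores = {"MMLU": 54.0, "TruthfulQA": 45.0}
print(average_drop(base_scores, tuned_scores))                  # 10.0
print(average_drop(base_scores, tuned_scores, relative=False))  # 5.5
```

Under this reading, "up to 16.7%" for SFT versus "close to 2%" for SSR compares the same averaged quantity for two fine-tuned versions of one base model. A positive value means the fine-tuned model scores lower than the base model on average.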
