Selective Self-to-Supervised Fine-Tuning for Generalization in Large Language Models

Sonam Gupta, Yatin Nandwani, Asaf Yehudai, Dinesh Khandelwal, Dinesh Raghu, Sachindra Joshi

TL;DR

Selective Self-to-Supervised Fine-Tuning (S3FT) tackles the loss of generalization seen with standard supervised fine-tuning by exploiting the fact that a query can have multiple valid responses and by regularizing toward the base model's distribution. It uses a judge to decide when the gold answer is actually needed, and otherwise prefers the model's own correct outputs or paraphrased gold responses as training targets, which reduces distribution drift. Empirically, S3FT achieves better in-domain performance than SFT with markedly smaller drops on general benchmarks, preserving more of the base model's capabilities while still improving on the target task. The approach is data-efficient, avoids replay buffers, and applies to a range of tasks, though it requires running inference over the training set and a reliable equivalence judge, which may limit some open-ended applications.

Abstract

Fine-tuning Large Language Models (LLMs) on specific datasets is a common practice to improve performance on target tasks. However, this performance gain often comes with overfitting, where the model becomes too specialized to either the task or the characteristics of the training data, resulting in a loss of generalization. This paper introduces Selective Self-to-Supervised Fine-Tuning (S3FT), a fine-tuning approach that achieves better performance than standard supervised fine-tuning (SFT) while improving generalization. S3FT leverages the existence of multiple valid responses to a query: by training on the model's own correct responses, it reduces model specialization during fine-tuning. S3FT first uses an appropriate judge to identify the training samples for which the model's response is correct. It then fine-tunes the model on those correct model responses, and on the gold response (or its paraphrase) for the remaining samples. The effectiveness of S3FT is demonstrated through experiments on mathematical reasoning, Python programming, and reading comprehension tasks. The results show that standard SFT can lead to an average performance drop of up to $4.4$ on multiple benchmarks, such as MMLU and TruthfulQA. In contrast, S3FT roughly halves this drop, to $2.5$, indicating better generalization than SFT while performing significantly better on the fine-tuning tasks.
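The per-sample target selection described above can be sketched in a few lines of Python. This is a minimal, hypothetical sketch rather than the authors' implementation: `predict`, `paraphrase`, and `judge` are placeholder callables standing in for base-model inference, base-model paraphrasing of $[x_i;y_i]$, and the equivalence judge, and the toy arithmetic data (with a judge that merely compares final tokens) exists only to exercise the three branches.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class Example:
    x: str  # input query
    y: str  # gold response


def build_s3ft_dataset(
    data: List[Example],
    predict: Callable[[str], str],           # hypothetical: base-model inference
    paraphrase: Callable[[str, str], str],   # hypothetical: base model paraphrases [x; y]
    judge: Callable[[str, str, str], bool],  # hypothetical: is candidate equivalent to gold?
) -> List[Tuple[str, str]]:
    """Per sample, pick the target closest to the base model that the judge accepts."""
    dataset = []
    for ex in data:
        y_hat = predict(ex.x)
        if judge(ex.x, y_hat, ex.y):
            dataset.append((ex.x, y_hat))    # model's own correct response
            continue
        y_tilde = paraphrase(ex.x, ex.y)
        if judge(ex.x, y_tilde, ex.y):
            dataset.append((ex.x, y_tilde))  # correct paraphrase of the gold response
        else:
            dataset.append((ex.x, ex.y))     # fall back to the gold response
    return dataset


# Toy usage: the "judge" compares final tokens, a stand-in for a real equivalence judge.
data = [Example("2+2?", "4"), Example("3+3?", "6"), Example("5+5?", "10")]
predictions = {"2+2?": "4", "3+3?": "7", "5+5?": "9"}
paraphrases = {"2+2?": "It is 4", "3+3?": "So the result is 6", "5+5?": "I think it is 11"}

result = build_s3ft_dataset(
    data,
    predict=lambda x: predictions[x],
    paraphrase=lambda x, y: paraphrases[x],
    judge=lambda x, cand, gold: cand.split()[-1] == gold.split()[-1],
)
print(result)  # → [('2+2?', '4'), ('3+3?', 'So the result is 6'), ('5+5?', '10')]
```

The ordering encodes the preference: the model's own correct answer is tried first, a judged-correct paraphrase second, and the unchanged gold response only as a last resort.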

Paper Structure

This paper contains 20 sections, 10 figures, 4 tables, 1 algorithm.

Figures (10)

  • Figure 1: An overview of our proposed approach: Given the input $x_i$ and its corresponding gold response $y_i$, we employ the base model $\mathcal{M}_{\theta}$ to transform $y_i$ so that it remains correct but is closer to the model's distribution. First, the model predicts the output $\hat{y}_i$. The judge decides whether $\hat{y}_i$ is correct. If true, $\hat{y}_i$ becomes part of the training dataset; otherwise, we paraphrase $[x_i;y_i]$ to obtain $\tilde{y}_i$. The judge evaluates the correctness of $\tilde{y}_i$. If true, we use $\tilde{y}_i$; otherwise, we use $y_i$ as the target response. The resulting dataset $\mathcal{D}'$ is used to train the model.
  • Figure 2: Histogram of the log probability assigned by Mistral-7B-Instruct-v0.2 to the gold responses, the paraphrases of the gold responses, and the model's own predictions. The distribution is based on 84 examples from the MBPP training data.
  • Figure 3: Prompt used for training the model on the GSM8K dataset.
  • Figure 4: Prompt used for predicting the base model's output on the GSM8K dataset.
  • Figure 5: Prompt used for predicting the base model's output on the NQ dataset.
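The quantity compared in Figure 2 is a response log-probability under the base model. Assuming it is the summed token-level log-probability given by the chain rule, $\log p_\theta(y \mid x) = \sum_t \log p_\theta(y_t \mid x, y_{<t})$ (the paper may instead use a length-normalized variant), it can be sketched as below. The `NextTokenDist` interface and `toy_dist` are hypothetical stand-ins for a real language model such as Mistral-7B-Instruct-v0.2.

```python
import math
from typing import Callable, Dict, List

# Hypothetical LM interface: (prompt tokens, response prefix) -> next-token probabilities.
NextTokenDist = Callable[[List[str], List[str]], Dict[str, float]]


def sequence_logprob(prompt: List[str], response: List[str], dist: NextTokenDist) -> float:
    """Chain rule: log p(y | x) = sum_t log p(y_t | x, y_<t)."""
    total = 0.0
    for t, token in enumerate(response):
        p = dist(prompt, response[:t]).get(token, 0.0)
        if p <= 0.0:
            return float("-inf")  # token is impossible under the model
        total += math.log(p)      # sum logs instead of multiplying probabilities
    return total


# Toy "model" that ignores context and prefers the token "a".
def toy_dist(prompt: List[str], prefix: List[str]) -> Dict[str, float]:
    return {"a": 0.5, "b": 0.25, "c": 0.25}


own = sequence_logprob(["q"], ["a", "a"], toy_dist)   # 2 * log 0.5
gold = sequence_logprob(["q"], ["c", "b"], toy_dist)  # 2 * log 0.25
print(own > gold)  # → True
```

Summing log-probabilities rather than multiplying raw probabilities avoids numerical underflow on long responses.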