FisherSFT: Data-Efficient Supervised Fine-Tuning of Language Models Using Information Gain

Rohan Deb, Kiran Thekumparampil, Kousha Kalantari, Gaurush Hiranandani, Shoham Sabach, Branislav Kveton

TL;DR

FisherSFT tackles the data efficiency challenge in supervised fine-tuning of large language models by selecting the most informative training sentences. It reframes sentence selection as a greedy optimal design problem that maximizes a lower bound on the log-determinant of the Hessian of the log-likelihood, enabling efficient estimation using a last-layer linearization with multinomial logistic regression. The paper provides a theoretical error bound of order $O(1/\sqrt{n})$ under mild assumptions and introduces a fast Greedy design variant with cached gains for scalability. Empirically, FisherSFT outperforms baselines on synthetic tasks, word embeddings, and GPT-2 Shakespeare data, and yields higher-quality text according to LLM-based evaluation, highlighting its practical impact for data-efficient language model adaptation.
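The greedy selection step can be sketched as greedy maximization of a log-determinant design objective. The sketch below is a minimal illustration, not the authors' implementation, and it rests on several assumptions:

- Each sentence is given as a matrix of last-layer token embeddings with shape `(n_tokens, d)`.
- The objective is the hypothetical regularized criterion $\log\det(\lambda I + \sum_{i \in S} X_i^\top X_i)$, where `reg` plays the role of $\lambda$.
- "Cached gains" are implemented as standard lazy greedy evaluation, which may differ from the paper's variant.

Lazy evaluation is valid here because this objective is monotone submodular. A gain computed against an earlier, smaller design is therefore an upper bound on the current gain.

```python
import heapq

import numpy as np


def logdet_gain(A_inv: np.ndarray, X: np.ndarray) -> float:
    """Return log det(A + X^T X) - log det(A) for positive definite A, given A_inv = A^{-1}.

    By the matrix determinant lemma this equals log det(I_n + X A^{-1} X^T),
    where X has shape (n_tokens, d), so only an n x n determinant is needed.
    """
    K = np.eye(X.shape[0]) + X @ A_inv @ X.T
    _, logdet = np.linalg.slogdet(K)
    return float(logdet)


def lazy_greedy_select(sentences, budget, reg=1e-3):
    """Greedily pick up to `budget` sentence indices maximizing
    log det(reg * I + sum_{i in S} X_i^T X_i).

    `sentences` is a list of (n_tokens_i, d) arrays, assumed to hold last-layer
    token embeddings. The objective is monotone submodular, so a gain cached
    for an earlier (smaller) design upper-bounds the current gain; a stale
    entry is rescored only when it reaches the top of the heap.
    """
    d = sentences[0].shape[1]
    A = reg * np.eye(d)  # regularized design matrix
    A_inv = np.linalg.inv(A)
    # Max-heap via negated gains: (-gain, index, number of picks when scored).
    heap = [(-logdet_gain(A_inv, X), i, 0) for i, X in enumerate(sentences)]
    heapq.heapify(heap)
    selected = []
    while len(selected) < budget and heap:
        _, i, scored_at = heapq.heappop(heap)
        if scored_at == len(selected):
            # Fresh gain that beats every cached upper bound: it is the true argmax.
            selected.append(i)
            A = A + sentences[i].T @ sentences[i]
            A_inv = np.linalg.inv(A)
        else:
            # Stale upper bound: rescore against the current design and push back.
            gain = logdet_gain(A_inv, sentences[i])
            heapq.heappush(heap, (-gain, i, len(selected)))
    return selected
```

For example, with `rng = np.random.default_rng(0)`, calling `lazy_greedy_select([rng.normal(size=(5, 16)) for _ in range(100)], budget=10)` returns the indices of 10 selected sentences. In practice many candidates are not rescored after each pick, which is where the savings over plain greedy come from.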

Abstract

Supervised fine-tuning (SFT) is a standard approach to adapting large language models (LLMs) to new domains. In this work, we improve the statistical efficiency of SFT by selecting an informative subset of training examples. Specifically, for a fixed budget of training examples, which determines the computational cost of fine-tuning, we determine the most informative ones. The key idea in our method is to select examples that maximize information gain, measured by the Hessian of the log-likelihood of the LLM. We approximate it efficiently by linearizing the LLM at the last layer using multinomial logistic regression models. Our approach is computationally efficient, analyzable, and performs well empirically. We demonstrate this on several problems, and back our claims with both quantitative results and an LLM evaluation.
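The last-layer linearization treats the LLM's final layer as multinomial logistic regression over the vocabulary. Below is a minimal NumPy sketch of the resulting per-token Hessian. It uses the standard softmax cross-entropy Hessian rather than code from the paper. The names are assumptions for illustration:

- `theta`, of shape `(d, L)`, holds the output weights.
- `x` is a token's last hidden state.
- `L` is the vocabulary size.

Parameters are flattened in column-major order. That ordering is what makes the Kronecker factors appear as `kron(diag(p) - p p^T, x x^T)`.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()


def token_nll_grad(theta: np.ndarray, x: np.ndarray, y: int) -> np.ndarray:
    """Gradient of -log softmax(theta^T x)[y] with respect to theta (shape d x L)."""
    p = softmax(theta.T @ x)
    p[y] -= 1.0  # p - e_y
    return np.outer(x, p)


def token_hessian(theta: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Hessian of the token NLL with respect to vec(theta), column-major order.

    Equals (diag(p) - p p^T) kron (x x^T) and does not depend on the label y.
    """
    p = softmax(theta.T @ x)
    return np.kron(np.diag(p) - np.outer(p, p), np.outer(x, x))


def sentence_hessian(theta: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Sum of per-token Hessians for a sentence with token embeddings X (n x d)."""
    return sum(token_hessian(theta, x) for x in X)
```

Because $(\mathrm{diag}(p) - pp^\top)\mathbf{1} = 0$, this matrix is singular along directions that add the same vector to every column of `theta`. Any log-determinant criterion built directly on it therefore needs a regularizer or a reduced parametrization.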

Paper Structure

This paper contains 16 sections, 11 theorems, 76 equations, 3 figures, 1 table, 4 algorithms.

Key Result

Lemma 3.1

Consider the loss function described in eq:loglik_subset. Then the Hessian of the loss is given by where $\otimes$ is the tensor product. Moreover, if holds for some $\gamma > 0$, then
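For reference, the textbook Hessian of multinomial logistic regression and the lower bound it yields are sketched below. This may not match the exact statement of Lemma 3.1. The notation is assumed, not taken from the paper:

- $x_{i,j} \in \mathbb{R}^d$ is the last-layer embedding of token $j$ in sentence $i$.
- $\Theta \in \mathbb{R}^{d \times L}$, and $p_{i,j} = \mathrm{softmax}(\Theta^\top x_{i,j})$.
- $S$ is the selected set of sentences.
- $\Theta$ is vectorized column by column.

```latex
\begin{aligned}
\nabla^2 \mathcal{L}_S(\Theta) &= \sum_{i \in S} \sum_{j} A_{i,j} \otimes x_{i,j} x_{i,j}^\top,
  \qquad A_{i,j} = \operatorname{diag}(p_{i,j}) - p_{i,j} p_{i,j}^\top, \\
A_{i,j} \succeq \gamma I_L \ \text{for all } i, j
  &\;\Longrightarrow\; \nabla^2 \mathcal{L}_S(\Theta) \succeq \gamma \, I_L \otimes V_S,
  \qquad V_S = \sum_{i \in S} \sum_{j} x_{i,j} x_{i,j}^\top, \\
  &\;\Longrightarrow\; \log\det \nabla^2 \mathcal{L}_S(\Theta) \ge L d \log \gamma + L \log\det V_S .
\end{aligned}
```

The argument proceeds in two steps:

- The first implication holds because $(A_{i,j} - \gamma I_L) \otimes x_{i,j} x_{i,j}^\top$ is a Kronecker product of positive semidefinite matrices and is therefore positive semidefinite.
- The second uses the monotonicity of $\log\det$ in the Loewner order together with $\det(\gamma I_L \otimes V_S) = \gamma^{Ld} (\det V_S)^L$.

The only set-dependent term is then $\log\det V_S$, which a greedy design can maximize. Because $A_{i,j}\mathbf{1} = 0$, a condition of this form with $\gamma > 0$ can only hold for a parametrization without the softmax's shift invariance, for example one with a fixed reference class.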

Figures (3)

  • Figure 1: Comparison of maximum and mean prediction errors on synthetic token vectors. The x-axis shows the number of sentences selected to train the model, and the y-axis shows the corresponding error averaged over 20 runs.
  • Figure 2: Comparison of maximum and mean prediction errors on word2vec token vectors. The x-axis shows the number of sentences selected to train the model, and the y-axis shows the corresponding error averaged over 20 runs.
  • Figure 3: Text generated by GPT-2 models fine-tuned on sentences selected by Uniform and by FisherSFT. The text from the FisherSFT model is more coherent.

Theorems & Definitions (18)

  • Lemma 3.1
  • proof
  • Theorem 4.3
  • Lemma 4.3
  • Lemma 4.3
  • Lemma 4.3
  • Proposition 2.1
  • proof
  • Lemma 3.0
  • proof
  • ...and 8 more