
Class-based Subset Selection for Transfer Learning under Extreme Label Shift

Akul Goyal, Carl Edwards

TL;DR

WaSS tackles extreme label shift in transfer learning by selecting and reweighting a subset of source classes with a Wasserstein-distance-based linear program in the embedding space, enabling effective transfer even when the source and target label spaces are disjoint. It then trains a classifier on the reweighted source data and fine-tunes it on a small target set with the feature extractor held fixed, supported by a theoretical bound that links the target error to the source error, the Wasserstein distance between domains, and model drift. The approach yields consistent improvements across multiple datasets and shift settings (including open-set scenarios), and is validated both quantitatively and with visualizations showing the reweighted source distribution becoming more aligned with the target. Overall, WaSS offers a practical, principled way to mitigate negative transfer by exploiting informative source classes rather than discarding data indiscriminately.
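The class-selection step can be sketched as one linear program whose variables are a transport plan and the source-class weights. This is a minimal sketch under stated assumptions, not the authors' implementation: it assumes a Euclidean ground cost (so the objective is $W_1$), uniform mass within each source class and over the target samples, and the helper name `select_class_weights` is hypothetical. Classes that end up with (near-)zero weight are effectively de-selected.

```python
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import cdist


def select_class_weights(src_emb, src_labels, tgt_emb):
    """Hypothetical helper: pick source-class weights w minimizing W_1 between
    the w-weighted source embeddings and the target embeddings.

    Assumptions (not from the paper): Euclidean ground cost, uniform mass
    within each source class, uniform mass over target samples.
    Returns (classes, weights, transport_cost).
    """
    src_emb = np.asarray(src_emb, dtype=float)
    tgt_emb = np.asarray(tgt_emb, dtype=float)
    classes, class_idx, class_counts = np.unique(
        np.asarray(src_labels), return_inverse=True, return_counts=True)
    n_s, n_t, k = len(src_emb), len(tgt_emb), len(classes)
    m = n_s * n_t  # number of transport-plan variables gamma[i, j]

    # Decision vector: [gamma flattened row-major (n_s x n_t), w (k weights)].
    cost = cdist(src_emb, tgt_emb)  # pairwise Euclidean distances
    c = np.concatenate([cost.ravel(), np.zeros(k)])

    A_eq = np.zeros((n_t + n_s, m + k))
    b_eq = np.zeros(n_t + n_s)
    # Target marginal: every target sample receives mass 1 / n_t.
    for j in range(n_t):
        A_eq[j, j:m:n_t] = 1.0
        b_eq[j] = 1.0 / n_t
    # Source marginal: sample i of class c sends out exactly w_c / n_c.
    for i in range(n_s):
        A_eq[n_t + i, i * n_t:(i + 1) * n_t] = 1.0
        A_eq[n_t + i, m + class_idx[i]] = -1.0 / class_counts[class_idx[i]]
    # sum(w) = 1 is implied: total mass sent equals total mass received (1).

    res = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return classes, res.x[m:], res.fun
```

For example, with two source classes clustered at the origin and near $(10, 10)$ and a target set near $(10, 10)$, the sketch puts all weight on the second class. Because the weights enter only through the source-marginal constraints, the problem stays linear; the dense plan has $n_s \times n_t$ variables, so large embedding sets may call for subsampling or per-class prototypes.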

Abstract

Existing work within transfer learning often follows a two-step process: pre-training over a large-scale source domain and then fine-tuning over limited samples from the target domain. Yet, despite its popularity, this methodology has been shown to suffer in the presence of distributional shift, specifically when the output spaces diverge. Previous work has focused on increasing model performance within this setting by identifying and classifying only the output classes shared between distributions. However, these methods are inherently limited, as they ignore classes outside the shared class set and thereby disregard information potentially relevant to the model transfer. This paper proposes a new process for few-shot transfer learning that selects and weights classes from the source domain to optimize the transfer between domains. More concretely, we use the Wasserstein distance to choose a set of source classes and their weights that minimize the distance between the source and target domains. To justify our proposed algorithm, we provide a generalization analysis of the performance of the learned classifier over the target domain and show that our method corresponds to a bound minimization algorithm. We empirically demonstrate the effectiveness of our approach (WaSS) by experimenting on several different datasets and showing superior performance within various label shift settings, including the extreme case where the label spaces are disjoint.
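One way selected class weights could enter source training is a weighted cross-entropy in which every source sample of class $c$ carries weight $w_c / n_c$, so de-selected classes ($w_c = 0$) drop out of the loss entirely. The per-class normalization and the name `class_weighted_ce` are assumptions for illustration, not taken from the paper.

```python
import numpy as np


def class_weighted_ce(logits, labels, class_weights):
    """Hypothetical loss: cross-entropy where each sample of class c is
    weighted by w_c / n_c (n_c = count of class c in the batch).

    Assumption: labels are integers 0..K-1 and len(class_weights) == K.
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    class_weights = np.asarray(class_weights, dtype=float)
    counts = np.bincount(labels, minlength=len(class_weights))
    # Per-sample weight w_{y_i} / n_{y_i}; these sum to sum(w) over the batch.
    sample_w = class_weights[labels] / counts[labels]
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]
    return float(np.sum(sample_w * nll))
```

With uniform logits over three classes the per-sample loss is $\log 3$, so any weight vector summing to one gives a total of $\log 3$; a class with weight zero contributes nothing regardless of its logits.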

Paper Structure

This paper contains 24 sections, 8 theorems, 33 equations, 10 figures, 2 tables.

Key Result

Theorem 2.1

Let ${\mathcal{H}}$ be a hypothesis space where all the classifiers (score functions) are $\rho$-Lipschitz continuous, and let $\mu_S$ and $\mu_T$ denote the marginal distributions of the source and target domains. Then, for every $h\in{\mathcal{H}}$, the following inequality holds:

$$ {\epsilon}_T(h) \leq {\epsilon}_S(h) + 2\rho\, W_1(\mu_S, \mu_T) + \lambda^*, $$

where $\lambda^* \vcentcolon= \min_{h'\in{\mathcal{H}}}\, [{\epsilon}_S(h') + {\epsilon}_T(h')]$ is the optimal joint 0-1 error that a single classifier could obtain over both domains.
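To make the terms concrete, the sketch below assumes the bound takes the form of Theorem 1 of Shen et al. (2018), ${\epsilon}_T(h) \leq {\epsilon}_S(h) + 2\rho\, W_1 + \lambda^*$, and evaluates its right-hand side. As a simplifying assumption it estimates $W_1$ between two one-dimensional empirical samples of equal size, where the optimal coupling matches sorted values; the function names are hypothetical. Note that $\lambda^*$ depends on target labels, so in practice it is not directly computable and is passed in as a given number here.

```python
def w1_1d(xs, ys):
    """Empirical W_1 between two 1-D samples of equal size: the mean
    absolute gap between sorted values (the monotone coupling is optimal)."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def target_error_bound(src_err, rho, w1, joint_err):
    """Right-hand side eps_S(h) + 2 * rho * W_1 + lambda* (assumed form)."""
    return src_err + 2.0 * rho * w1 + joint_err


w1 = w1_1d([0.0, 1.0, 2.0], [0.5, 1.5, 2.5])  # every point shifted by 0.5
print(w1)                                        # 0.5
print(target_error_bound(0.1, 0.2, w1, 0.05))    # approximately 0.35
```

The example shows how the bound trades off the terms: shrinking the domain distance (the quantity the class reweighting targets) tightens the guarantee only as long as the source error does not grow faster.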

Figures (10)

  • Figure 1:
  • Figure 2: Accuracy on the target domain of the downstream classifier when the source data are weighted by the OSS and ALL class selection methods, in an open-set setting. TV distance of two different class selection methods based on the classes overlapping between domains.
  • Figure 3: Class distributions before (inner circle) and after (outer circle) applying our method to arbitrarily selected test classes. The same color corresponds to the same class (e.g., the share of Deer changes from 14% to 20%). Test classes for F-MNIST are $\{\text{T-shirt/top, Trouser, Pullover}\}$ and for CIFAR-10 are $\{\text{Airplane, Automobile, Bird}\}$.
  • Figure 4: Class distribution of the test set (inner circle) and of the training set (outer circle) after weighting by the class selection methods WaSS and OSS. The same color corresponds to the same class. Test classes: $\{\text{Ankle boot, T-shirt/top, Trouser}\}$.
  • Figure 5: Accuracy on the target domain of the downstream classifier when the source data are weighted by the WaSS and ALL class selection methods, under two different embedding dimensions.
  • ...and 5 more figures
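Figure 2 reports a TV (total variation) distance between class distributions. For discrete distributions $p$ and $q$ over classes, the standard definition is $\mathrm{TV}(p, q) = \tfrac{1}{2}\sum_c |p_c - q_c|$. The sketch below computes it for class proportions stored as dictionaries; which pair of distributions the figure compares is not specified here, and the function name and dictionary representation are assumptions for illustration.

```python
def tv_distance(p, q):
    """Total variation distance between two discrete class distributions,
    each a dict mapping class name -> probability (missing classes = 0)."""
    classes = set(p) | set(q)
    return 0.5 * sum(abs(p.get(c, 0.0) - q.get(c, 0.0)) for c in classes)


# Partially overlapping label spaces give an intermediate distance...
print(tv_distance({"Deer": 0.5, "Cat": 0.5}, {"Cat": 0.5, "Dog": 0.5}))  # 0.5
# ...while disjoint label spaces (the extreme shift case) give the maximum, 1.
print(tv_distance({"Deer": 1.0}, {"Dog": 1.0}))  # 1.0
```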

Theorems & Definitions (10)

  • Definition 2.1: Wasserstein Distance
  • Theorem 2.1: Theorem 1 of Shen et al. (2018)
  • Definition 3.1: Induced classifier
  • Proposition 3.1
  • Proposition 3.2
  • Lemma 3.1
  • Theorem 3.1
  • Lemma 3.2
  • Proposition F.1
  • Proposition F.2
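For reference, the textbook (Kantorovich) form of the Wasserstein distance of order $p$ between probability measures $\mu$ and $\nu$ on a metric space $(\mathcal{X}, d)$ is given below, together with the Kantorovich-Rubinstein dual form of $W_1$. This is the standard statement rather than a transcription of Definition 2.1, whose notation may differ. The dual form is what lets a bound such as Theorem 2.1 control the error gap of $\rho$-Lipschitz classifiers by a multiple of $W_1$.

```latex
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu, \nu)}
    \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^p \, d\gamma(x, y) \right)^{1/p},
\qquad
W_1(\mu, \nu) = \sup_{\|f\|_{L} \le 1}
    \mathbb{E}_{x \sim \mu}[f(x)] - \mathbb{E}_{y \sim \nu}[f(y)]
```

Here $\Pi(\mu, \nu)$ is the set of couplings (joint distributions) with marginals $\mu$ and $\nu$, and $\|f\|_{L}$ is the Lipschitz constant of $f$.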