Unsupervised optimal deep transfer learning for classification under general conditional shift

Junjun Lang, Yukun Liu

TL;DR

The paper tackles classification under distribution shift with unlabeled target data by introducing General Conditional Shift (GCS), which generalizes label shift and enables identifiability of the target distribution and shift function. It combines a deep neural network-based estimator for the source conditional probabilities with a pseudo-maximum-likelihood estimator for the target class proportions, and constructs a Bayes classifier that cancels the unknown shift function $h(\cdot)$. Theoretical results establish concentration bounds for the estimators and prove minimax optimality of the proposed classifier up to logarithmic factors, with rates governed by the intrinsic dimension of the source classification function. Empirically, the method demonstrates strong performance in simulations and on an Alzheimer's disease dataset, illustrating robustness to high dimensionality and effectiveness when GCS holds, even when label shift fails. The work provides a principled, scalable approach for unsupervised transfer learning in nonparametric multi-class classification with practical impact for biomedical and other real-world domains.

Abstract

Classifiers trained solely on labeled source data may yield misleading results when applied to unlabeled target data drawn from a different distribution. Transfer learning can rectify this by transferring knowledge from source to target data, but its effectiveness frequently relies on stringent assumptions, such as label shift. In this paper, we introduce a novel General Conditional Shift (GCS) assumption, which encompasses label shift as a special scenario. Under GCS, we demonstrate that both the target distribution and the shift function are identifiable. To estimate the conditional probabilities $\bm{\eta}_P$ for source data, we propose leveraging deep neural networks (DNNs). Subsequent to transferring the DNN estimator, we estimate the target label distribution $\bm{\pi}_Q$ utilizing a pseudo-maximum likelihood approach. Ultimately, by incorporating these estimates and circumventing the need to estimate the shift function, we construct our proposed Bayes classifier. We establish concentration bounds for our estimators of both $\bm{\eta}_P$ and $\bm{\pi}_Q$ in terms of the intrinsic dimension of $\bm{\eta}_P$. Notably, our DNN-based classifier achieves the optimal minimax rate, up to a logarithmic factor. A key advantage of our method is its capacity to effectively combat the curse of dimensionality when $\bm{\eta}_P$ exhibits a low-dimensional structure. Numerical simulations, along with an analysis of an Alzheimer's disease dataset, underscore its exceptional performance.
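The three-step pipeline in the abstract (estimate the source posteriors, estimate the target class proportions by pseudo-maximum likelihood, then plug both into a Bayes rule) can be illustrated with a minimal sketch. The sketch rests on several assumptions that are not taken from the paper: it covers only the label-shift special case, where the plug-in target posterior is proportional to $(\pi_{Q,k}/\pi_{P,k})\,\eta_{P,k}(x)$; it maximizes the pseudo-likelihood with standard EM iterations, which may differ from the authors' optimizer; and the exact source posteriors of a toy Gaussian model stand in for a fitted DNN. The names `estimate_target_priors` and `plug_in_classifier` are hypothetical.

```python
import numpy as np


def estimate_target_priors(eta_p, pi_p, n_iter=500, tol=1e-10):
    """Pseudo-MLE of the target class proportions (label-shift case only).

    Maximizes sum_j log sum_k (pi_k / pi_p[k]) * eta_p[j, k] over the
    probability simplex with EM iterations.
    eta_p : (n_Q, K) source posteriors evaluated at unlabeled target points
    pi_p  : (K,) source class proportions
    """
    eta_p = np.asarray(eta_p, dtype=float)
    pi_p = np.asarray(pi_p, dtype=float)
    pi_q = pi_p.copy()  # start from the source proportions
    for _ in range(n_iter):
        w = eta_p * (pi_q / pi_p)          # unnormalized target posteriors
        w /= w.sum(axis=1, keepdims=True)  # E-step: normalize each row
        new_pi_q = w.mean(axis=0)          # M-step: average responsibilities
        converged = np.max(np.abs(new_pi_q - pi_q)) < tol
        pi_q = new_pi_q
        if converged:
            break
    return pi_q


def plug_in_classifier(eta_p, pi_p, pi_q):
    """Predict the class maximizing the reweighted source posterior."""
    scores = np.asarray(eta_p, dtype=float) * (np.asarray(pi_q) / np.asarray(pi_p))
    return scores.argmax(axis=1)


# Toy data (an assumption for illustration): X | Y=k ~ N(2k - 1, 1), k in {0, 1}.
rng = np.random.default_rng(0)
pi_p = np.array([0.5, 0.5])          # source class proportions
true_pi_q = np.array([0.25, 0.75])   # target class proportions, unknown in practice
y_target = rng.choice(2, size=4000, p=true_pi_q)
x_target = rng.normal(loc=2.0 * y_target - 1.0, scale=1.0)

# With equal source priors the exact source posterior of class 1 is sigmoid(2x);
# in practice this matrix would come from a classifier trained on source data.
p1 = 1.0 / (1.0 + np.exp(-2.0 * x_target))
eta_p_at_target = np.column_stack([1.0 - p1, p1])

pi_q_hat = estimate_target_priors(eta_p_at_target, pi_p)
y_pred = plug_in_classifier(eta_p_at_target, pi_p, pi_q_hat)
print("estimated target proportions:", pi_q_hat)
print("target accuracy:", np.mean(y_pred == y_target))
```

Only unlabeled target covariates enter `estimate_target_priors`; the labels `y_target` are used solely to report accuracy in the toy example. Under the more general GCS model the reweighting step would have to follow the paper's own classifier, which this sketch does not reproduce.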

Paper Structure

This paper contains 22 sections, 6 theorems, 24 equations, 5 figures.

Key Result

Lemma 1

Under the GCS and identifiability assumptions, the target distribution $Q_{(X, Y)}$ and the function $h(\cdot)$ in the GCS model are identifiable from the observed data.

Figures (5)

  • Figure 1: A three-layer deep neural network with $K = 2$ and ${\bf p} = (4,3,3,1)$.
  • Figure 2: Excess risks of the classifiers under comparison versus $n_P$ when $n_Q = 400$. Upper panel: the balanced case ($\pi_{Q,1} = 0.5$); Lower panel: the unbalanced case ($\pi_{Q,1} = 0.25$).
  • Figure 3: Mean squared errors of the three estimators of $\pi_{Q,1}$ under comparison with $n_Q = 400$. Upper panel: the balanced case ($\pi_{Q,1} = 0.5$); Lower panel: the unbalanced case ($\pi_{Q,1} = 0.25$).
  • Figure 4: Histograms of estimated density ratios $r_1(x)$ and $r_2(x)$ (plot (a)), and their difference $r_1(x)-r_2(x)$ (plot (b)) evaluated at the covariates of the test data.
  • Figure 5: Plots of analysis results for the Alzheimer's Disease Dataset when $p=0.5$, $0.6$, or $0.7$. Left plot: Correct classification proportions of the five classifiers under comparison. Right plot: Absolute relative bias of DNN, Kernel and KNN with respect to the proportion of label $1$ in the target data.

Theorems & Definitions (11)

  • Lemma 1
  • Definition 1: Hölder classes
  • Definition 2: Composite smoothness function class [Schmidt-Hieber2020]
  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Example 1: Generalized additive model (GAM) [generalizedadditive2007]
  • Example 2: Generalized hierarchical interaction model [kohler2016nonparametric]
  • Theorem 4
  • Corollary 1
  • ...and 1 more