Unsupervised optimal deep transfer learning for classification under general conditional shift
Junjun Lang, Yukun Liu
TL;DR
The paper tackles classification under distribution shift with unlabeled target data by introducing General Conditional Shift (GCS), which generalizes label shift and enables identifiability of the target distribution and shift function. It combines a deep neural network-based estimator for the source conditional probabilities with a pseudo-maximum-likelihood estimator for the target class proportions, and constructs a Bayes classifier that cancels the unknown shift function $h(\cdot)$. Theoretical results establish concentration bounds for the estimators and prove minimax optimality of the proposed classifier up to logarithmic factors, with rates governed by the intrinsic dimension of the source classification function. Empirically, the method demonstrates strong performance in simulations and on an Alzheimer's disease dataset, illustrating robustness to high dimensionality and effectiveness when GCS holds, even when label shift fails. The work provides a principled, scalable approach for unsupervised transfer learning in nonparametric multi-class classification with practical impact for biomedical and other real-world domains.
Abstract
Classifiers trained solely on labeled source data may yield misleading results when applied to unlabeled target data drawn from a different distribution. Transfer learning can rectify this by transferring knowledge from source to target data, but its effectiveness frequently relies on stringent assumptions, such as label shift. In this paper, we introduce a novel General Conditional Shift (GCS) assumption, which encompasses label shift as a special scenario. Under GCS, we demonstrate that both the target distribution and the shift function are identifiable. To estimate the conditional probabilities $\bm{\eta}_P$ for source data, we propose leveraging deep neural networks (DNNs). Subsequent to transferring the DNN estimator, we estimate the target label distribution $\bm{\pi}_Q$ utilizing a pseudo-maximum likelihood approach. Ultimately, by incorporating these estimates and circumventing the need to estimate the shift function, we construct our proposed Bayes classifier. We establish concentration bounds for our estimators of both $\bm{\eta}_P$ and $\bm{\pi}_Q$ in terms of the intrinsic dimension of $\bm{\eta}_P$. Notably, our DNN-based classifier achieves the optimal minimax rate, up to a logarithmic factor. A key advantage of our method is its capacity to effectively combat the curse of dimensionality when $\bm{\eta}_P$ exhibits a low-dimensional structure. Numerical simulations, along with an analysis of an Alzheimer's disease dataset, underscore its exceptional performance.
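To make the estimate-then-classify pipeline concrete, here is a minimal sketch restricted to the label-shift special case that the abstract mentions. It is not the paper's GCS estimator, whose pseudo-likelihood is not given here. The sketch assumes the source posteriors $\bm{\eta}_P$ have already been fitted (by a DNN or any other model) and evaluated on the unlabeled target points, and it uses the standard EM update for class proportions under label shift. The function names `estimate_target_priors` and `classify_target` are hypothetical.

```python
def estimate_target_priors(eta_p, pi_p, n_iter=1000, tol=1e-10):
    """EM estimate of target class proportions (label-shift special case only).

    eta_p: list of rows; eta_p[i][k] is the fitted source posterior
           P(Y = k | X = x_i) evaluated at unlabeled target point x_i.
    pi_p:  source class proportions.
    Assumption: this maximizes sum_i log sum_k (pi_q[k] / pi_p[k]) * eta_p[i][k],
    which is the target log-likelihood under label shift, not under general GCS.
    """
    n_classes = len(pi_p)
    pi_q = list(pi_p)  # start from the source proportions
    for _ in range(n_iter):
        totals = [0.0] * n_classes
        for row in eta_p:
            # E-step: reweight source posteriors by the current prior ratio
            weighted = [row[k] * pi_q[k] / pi_p[k] for k in range(n_classes)]
            norm = sum(weighted)
            for k in range(n_classes):
                totals[k] += weighted[k] / norm
        # M-step: average the reweighted posteriors over the target sample
        new_pi_q = [t / len(eta_p) for t in totals]
        change = max(abs(a - b) for a, b in zip(new_pi_q, pi_q))
        pi_q = new_pi_q
        if change < tol:
            break
    return pi_q


def classify_target(row, pi_p, pi_q):
    """Plug-in Bayes rule under label shift: argmax_k (pi_q[k] / pi_p[k]) * eta_p[k]."""
    scores = [row[k] * pi_q[k] / pi_p[k] for k in range(len(pi_p))]
    return max(range(len(scores)), key=scores.__getitem__)
```

In use, one would call `estimate_target_priors` once on the posteriors of all unlabeled target points and then apply `classify_target` to each row. Normalizing constants common to all classes drop out of the argmax, which is consistent with the abstract's remark that the classifier avoids estimating the shift function, although the exact GCS form of that cancellation is an assumption here.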
