Wasserstein Transfer Learning

Kaicheng Zhang, Sinian Zhang, Doudou Zhou, Yidong Zhou

TL;DR

This paper introduces a transfer learning framework for regression models whose outputs are probability distributions in the Wasserstein space, and proposes an estimator with provable asymptotic convergence rates that quantify how domain similarity affects transfer efficiency.

Abstract

Transfer learning is a powerful paradigm for leveraging knowledge from source domains to enhance learning in a target domain. However, traditional transfer learning approaches often focus on scalar or multivariate data within Euclidean spaces, limiting their applicability to complex data structures such as probability distributions. To address this limitation, we introduce a novel transfer learning framework for regression models whose outputs are probability distributions residing in the Wasserstein space. When the informative subset of transferable source domains is known, we propose an estimator with provable asymptotic convergence rates, quantifying the impact of domain similarity on transfer efficiency. For cases where the informative subset is unknown, we develop a data-driven transfer learning procedure designed to mitigate negative transfer. The proposed methods are supported by rigorous theoretical analysis and are validated through extensive simulations and real-world applications. The code is available at https://github.com/h7nian/WaTL
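To make the notion of distributions "residing in the Wasserstein space" concrete, the sketch below computes the 2-Wasserstein distance between two univariate distributions, which equals the $L^2$ distance between their quantile functions. It is an illustration only and is not taken from the WaTL repository; the function name `wasserstein2_from_quantiles` and the grid-based discretization are assumptions made here.

```python
import numpy as np

def wasserstein2_from_quantiles(q1, q2, grid):
    """2-Wasserstein distance between two univariate distributions.

    q1, q2: quantile functions evaluated on a common probability grid.
    grid:   increasing probabilities in [0, 1].
    Hypothetical helper for illustration; not part of the paper's code.
    """
    diff_sq = (np.asarray(q1, dtype=float) - np.asarray(q2, dtype=float)) ** 2
    # Trapezoidal rule for the integral of (F1^{-1} - F2^{-1})^2 over the grid.
    integral = np.sum(0.5 * (diff_sq[1:] + diff_sq[:-1]) * np.diff(grid))
    return float(np.sqrt(integral))

# Uniform(0, 1) versus Uniform(0.5, 1.5): a pure location shift of 0.5.
grid = np.linspace(0.0, 1.0, 101)
q_a = grid
q_b = grid + 0.5
d = wasserstein2_from_quantiles(q_a, q_b, grid)
```

For a pure location shift, the distance equals the size of the shift, so `d` is 0.5 here.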

Paper Structure

This paper contains 27 sections, 10 theorems, 102 equations, 3 figures, 1 table, 6 algorithms.

Key Result

Lemma 1

Let $\widehat{f}^{(k)}(x)=n_k^{-1}\sum_{i=1}^{n_k}s_{iG}^{(k)}(x)F_{\nu_i^{(k)}}^{-1}$ and its population counterpart be defined as $f^{(k)}(x)=E\{s_{G}^{(k)}(x)F_{\nu^{(k)}}^{-1}\}$ for $k=0, \ldots, K$. Then under Condition 1, $\|\widehat{f}^{(k)}(x)-f^{(k)}(x)\|_2=O_p(n_k^{-1/2})$.
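The estimator in Lemma 1 is a weighted average of observed quantile functions, and a minimal sketch for a single site $k$ might look as follows. This page does not define the weights $s_{iG}^{(k)}(x)$; the sketch assumes the standard global Fréchet regression weights $s_{iG}(x)=1+(X_i-\bar{X})^{\top}\widehat{\Sigma}^{-1}(x-\bar{X})$, and the names `global_frechet_weights` and `f_hat` are hypothetical rather than drawn from the WaTL repository.

```python
import numpy as np

def global_frechet_weights(X, x):
    """Assumed weights: s_iG(x) = 1 + (X_i - Xbar)^T Sigma^{-1} (x - Xbar).

    X: (n, p) array of predictors; x: (p,) query point.
    """
    X = np.asarray(X, dtype=float)
    x = np.asarray(x, dtype=float)
    x_bar = X.mean(axis=0)
    centered = X - x_bar
    sigma = centered.T @ centered / X.shape[0]  # empirical covariance (1/n scaling)
    return 1.0 + centered @ np.linalg.solve(sigma, x - x_bar)

def f_hat(X, Q, x):
    """Weighted average n^{-1} * sum_i s_iG(x) * F_{nu_i}^{-1}.

    Q: (n, m) array; row i is the i-th quantile function on a common grid.
    """
    s = global_frechet_weights(X, x)
    return s @ np.asarray(Q, dtype=float) / len(s)

# Toy data: each response distribution is a base distribution shifted by 2 * X_i.
rng = np.random.default_rng(0)
X = rng.uniform(size=(50, 1))
base = np.linspace(0.01, 0.99, 99)  # quantile function of the base distribution
Q = base[None, :] + 2.0 * X
est = f_hat(X, Q, np.array([0.3]))
```

Because the toy quantile functions are exactly linear in the predictor, `est` recovers `base + 0.6` up to floating-point error. In general the weighted average need not be monotone, so it may not itself be a valid quantile function.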

Figures (3)

  • Figure 1: (a) Root mean squared prediction risk (RMSPR) of WaTL, Only Source, and Only Target methods under varying target sample sizes, with source sample sizes $\tau = 100$ (left) and $\tau = 200$ (right); (b) Selection rate of each source site as $\psi$ increases.
  • Figure 2: (a) Root mean squared prediction risk (RMSPR) of WaTL and Only Target methods for females and males, evaluated using five-fold cross-validation; (b) Cumulative distribution function of physical activity levels for one selected female (left) and one selected male (right), along with estimates from WaTL and Only Target methods.
  • Figure 3: (a) Age-at-death densities of developed and developing countries; (b) Root mean squared prediction risk (RMSPR) of WaTL and Only Target methods for human mortality data.

Theorems & Definitions (24)

  • Lemma 1
  • Remark 1
  • Theorem 1
  • Theorem 2
  • Remark 2
  • Theorem 3
  • Lemma 2
  • proof
  • proof
  • Lemma 3
  • ...and 14 more