Domain Adaptation and Entanglement: an Optimal Transport Perspective

Okan Koç, Alexander Soen, Chao-Kai Chiang, Masashi Sugiyama

TL;DR

The paper tackles robustness to distribution shifts by formulating unsupervised domain adaptation (UDA) bounds through optimal transport. It introduces entanglement, an unoptimizable component capturing how aligning marginals can degrade target accuracy via misaligned conditionals. The authors develop a theoretical OT-based framework, define label- and prediction-entanglement, and derive an Oracle Bound that connects source risk, marginal alignment, and entanglement. Empirically, entanglement explains why some domain-matching methods fail under distribution shifts and how assumptions like Close Conditionals and Gradual Shift can improve transfer, with practical implications for choosing loss functions, models, and optimization strategies in UDA tasks.

Abstract

Current machine learning systems are brittle in the face of distribution shifts (DS), where the target distribution that the system is tested on differs from the source distribution used to train the system. This problem of robustness to DS has been studied extensively in the field of domain adaptation. For deep neural networks, a popular framework for unsupervised domain adaptation (UDA) is domain matching, in which algorithms try to align the marginal distributions in the feature or output space. The current theoretical understanding of these methods, however, is limited, and existing theoretical results are not precise enough to characterize their performance in practice. In this paper, we derive new bounds based on optimal transport that analyze the UDA problem. Our new bounds include a term which we dub entanglement, consisting of an expectation of Wasserstein distance between conditionals with respect to changing data distributions. Analysis of the entanglement term provides a novel perspective on the unoptimizable aspects of UDA. In various experiments with multiple models across several DS scenarios, we show that this term can be used to explain the varying performance of UDA algorithms.

Paper Structure

This paper contains 56 sections, 22 theorems, 77 equations, 5 figures, 7 tables.

Key Result

Lemma 3.1

Suppose the loss function $\ell \colon \mathcal{Y} \times \mathcal{Y} \rightarrow \mathbb{R}$ satisfies Assumption assum:metric. Then the target risk of a classifier $f \colon \mathcal{X} \rightarrow \mathcal{Y}$ is bounded by where the inequality $\rm (s)$ holds for surjective $f$. Additionally, $\rm (s)$ is an equality whenever $f$ is invertible. We remind the readers that $(f\sharp p)(\hat{y},

Figures (5)

  • Figure 1: We derive new bounds for unsupervised domain adaptation based on optimal transport and introduce a term called entanglement, which quantifies the loss in accuracy that can happen when aligning marginals. During the marginal alignment step, optimal transport associates source inputs (gray) with target inputs (brown), and stochastic gradient descent tries to find neural network parameters that pull these coupled points closer together. Entanglement measures the average loss of associating pairs with different labels. In this figure we visualize such an entangled (source, target) output pair using the Portraits dataset. Shaded circles correspond to images labeled female and empty circles to images labeled male. The dotted line indicates the decision boundary in the output space of a convolutional neural network, separating male predictions from female predictions.
  • Figure 2: Target accuracy plots for the MNIST $\to$ USPS scenario shown in Table 1 using the model MLP. Corresponding entanglement estimates are shown in the right-hand panel.
  • Figure 5: Target loss over the epochs for the MNIST $\to$ MNIST-M scenario shown in Table 2 using the model conv1. Corresponding entanglement estimates as well as the WRR values are included for comparison.
  • Figure 8: Target accuracy plots for the USPS $\to$ MNIST scenario shown in Table 1 using the model conv1. Left/right hand sides correspond to batch sizes $64$ and $512$ respectively.
  • Figure 10: A case in the USPS $\to$ MNIST scenario for the conv1 model where increasing the number of epochs worsens the accuracy of the methods. All hyperparameter settings and model configurations are the same as in Table 1 except for the number of epochs.
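
The Figure 1 caption describes entanglement as the average loss incurred when optimal transport couples source and target points that carry different labels. The following is a minimal illustrative sketch of that idea, not the authors' estimator. It assumes equally sized, uniformly weighted batches (in which case an optimal coupling can be found as a one-to-one assignment), a Euclidean ground cost on model outputs, and a 0-1 loss on labels. The helper names `couple_outputs` and `label_entanglement_01` and the toy data are hypothetical. Target labels are used here purely as an oracle diagnostic, since they are unavailable during UDA training.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def couple_outputs(src_out, tgt_out):
    """Couple source and target outputs by solving an assignment problem.

    Assumption: both batches have the same size and uniform weights, so an
    optimal transport plan can be taken to be a permutation.
    Returns the matched indices and the mean transport cost (an empirical
    Wasserstein-1 estimate under the Euclidean ground cost).
    """
    cost = cdist(src_out, tgt_out, metric="euclidean")
    rows, cols = linear_sum_assignment(cost)
    return rows, cols, float(cost[rows, cols].mean())


def label_entanglement_01(src_labels, tgt_labels, rows, cols):
    """Fraction of coupled (source, target) pairs whose labels differ.

    Assumption: 0-1 loss on labels; the paper's loss may differ.
    """
    return float(np.mean(src_labels[rows] != tgt_labels[cols]))


# Toy outputs: the target batch is the source batch shifted by 0.1 per axis.
src_out = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
tgt_out = src_out + 0.1
src_labels = np.array([0, 0, 1, 1])
tgt_labels = np.array([0, 1, 1, 1])  # one target label disagrees with its match

rows, cols, w1 = couple_outputs(src_out, tgt_out)
entanglement = label_entanglement_01(src_labels, tgt_labels, rows, cols)
print(f"mean transport cost: {w1:.4f}, 0-1 entanglement: {entanglement:.2f}")
```

In this toy case the marginals are almost perfectly aligned (small transport cost), yet one in four coupled pairs has mismatched labels, which is the kind of gap between marginal alignment and target accuracy that the entanglement term is meant to capture.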

Theorems & Definitions (57)

  • Definition 3.1
  • Definition 3.2
  • Lemma 3.1
  • proof : Proof Sketch
  • Lemma 3.2
  • proof : Proof Sketch
  • Definition 3.3
  • Corollary 3.1
  • Definition 3.4
  • Theorem 3.1
  • ...and 47 more