Domain Generalization with Small Data

Kecheng Chen, Elena Gal, Hong Yan, Haoliang Li

TL;DR

This work tackles domain generalization (DG) under small-data conditions in medical imaging by introducing probabilistic embeddings obtained from Bayesian neural networks. It extends the maximum mean discrepancy to a distribution-over-distributions framework ($P$-MMD) and couples it with a probabilistic contrastive semantic alignment ($P$-CSA) to capture both global and local domain alignment. Empirical results on epithelium-stroma classification, skin lesion classification, and spinal cord gray matter segmentation show improved cross-domain performance over state-of-the-art baselines, attributed to uncertainty-aware representations and distribution-aware learning. The approach offers a principled, scalable path toward robust DG in data-scarce clinical settings, along with guidance on hyperparameter trade-offs (e.g., the number of Monte Carlo samples $T$) and avenues for future improvement.

Abstract

In this work, we tackle the problem of domain generalization in the context of insufficient samples. Instead of extracting latent feature embeddings with deterministic models, we learn a domain-invariant representation within a probabilistic framework by mapping each data point to a probabilistic embedding. Specifically, we first extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD that measures the discrepancy between mixture distributions (i.e., source domains) composed of a series of latent distributions rather than latent points. Moreover, instead of imposing the contrastive semantic alignment (CSA) loss on pairs of latent points, a novel probabilistic CSA loss pulls positive probabilistic embedding pairs closer while pushing negative pairs apart. Benefiting from the representations learned by the probabilistic model, our method combines a measurement on the distribution over distributions (i.e., global alignment) with distribution-based contrastive semantic alignment (i.e., local alignment). Extensive experiments on three challenging medical datasets show the effectiveness of the proposed method in the insufficient-data setting compared with state-of-the-art methods.
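The global alignment idea can be sketched in a few lines. The following is a minimal pure-Python sketch, not the authors' implementation: it assumes each probabilistic embedding is a diagonal Gaussian represented by $T$ Monte Carlo samples, uses a Gaussian RBF kernel on latent points, and uses a Gaussian kernel on the RKHS distance between kernel mean embeddings as the distribution-level kernel. The function names (`sample_embedding`, `level2_kernel`, `p_mmd2`) and the bandwidths `gamma` and `tau` are hypothetical.

```python
import math
import random


def rbf(x, y, gamma=1.0):
    """Gaussian RBF kernel between two latent points given as lists of floats."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2.0 * gamma ** 2))


def sample_embedding(mu, sigma, T, rng):
    """Draw T Monte Carlo samples from a diagonal Gaussian N(mu, diag(sigma^2))."""
    return [[rng.gauss(m, s) for m, s in zip(mu, sigma)] for _ in range(T)]


def mean_embedding_inner(xs, ys, gamma=1.0):
    """Estimate <mu_P, mu_Q>_H from samples xs ~ P and ys ~ Q (biased, includes self-pairs)."""
    return sum(rbf(x, y, gamma) for x in xs for y in ys) / (len(xs) * len(ys))


def level2_kernel(xs, ys, gamma=1.0, tau=1.0):
    """Hypothetical distribution-level kernel: exp(-||mu_P - mu_Q||_H^2 / (2 tau^2))."""
    d2 = (mean_embedding_inner(xs, xs, gamma) + mean_embedding_inner(ys, ys, gamma)
          - 2.0 * mean_embedding_inner(xs, ys, gamma))
    return math.exp(-max(d2, 0.0) / (2.0 * tau ** 2))


def p_mmd2(domain_a, domain_b, gamma=1.0, tau=1.0):
    """Biased squared MMD between two domains, each a list of per-sample embeddings."""
    def avg(us, vs):
        return sum(level2_kernel(u, v, gamma, tau) for u in us for v in vs) / (len(us) * len(vs))
    return avg(domain_a, domain_a) + avg(domain_b, domain_b) - 2.0 * avg(domain_a, domain_b)


# Usage: two toy source domains of 5 probabilistic embeddings each, T = 10 samples.
rng = random.Random(0)
T = 10
dom_a = [sample_embedding([0.0, 0.0], [0.1, 0.1], T, rng) for _ in range(5)]
dom_b = [sample_embedding([2.0, 2.0], [0.1, 0.1], T, rng) for _ in range(5)]
print(p_mmd2(dom_a, dom_a))  # 0.0: a domain has no discrepancy with itself
print(p_mmd2(dom_a, dom_b))  # clearly positive for the shifted domain
```

Treating each domain as a set of sample sets (one per data point) is what makes this a distribution over distributions: the inner kernel compares latent points, while the outer kernel compares whole probabilistic embeddings.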

Paper Structure (22 sections, 1 theorem, 17 equations, 5 figures, 10 tables)

This paper contains 22 sections, 1 theorem, 17 equations, 5 figures, 10 tables.

Key Result

Theorem 1

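Let $\mathbb{P}_1,\ldots,\mathbb{P}_N$ be probability distributions, let $\hat{\mathbb{P}}:=\frac{1}{N}\sum_{i=1}^{N}\mathbb{P}_i$, and let $\mu_{\mathbb{P}}$ denote the kernel mean embedding of $\mathbb{P}$ in a reproducing kernel Hilbert space $\mathcal{H}$ with a characteristic kernel. Then the distributional variance $\frac{1}{N}\sum_{i=1}^{N}\lVert\mu_{\mathbb{P}_i}-\mu_{\hat{\mathbb{P}}}\rVert_{\mathcal{H}}^{2}$ is 0 iff $\mathbb{P}_1=\mathbb{P}_2=\cdots=\mathbb{P}_N$.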
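Theorem 1 can be checked numerically with empirical kernel mean embeddings. Using the squared RKHS norm, as in the usual definition of distributional variance, and the linearity $\mu_{\hat{\mathbb{P}}}=\frac{1}{N}\sum_{j}\mu_{\mathbb{P}_j}$, the variance reduces to $\frac{1}{N}\operatorname{tr}(G)-\frac{1}{N^2}\sum_{j,k}G_{jk}$ with $G_{jk}=\langle\mu_{\mathbb{P}_j},\mu_{\mathbb{P}_k}\rangle_{\mathcal{H}}$. The sketch below assumes a Gaussian RBF kernel (which is characteristic) and estimates each $G_{jk}$ from finite samples; the helper names are hypothetical, and with independent samples from equal distributions the estimate would only be approximately zero.

```python
import math
import random


def rbf(x, y, gamma=1.0):
    """Gaussian RBF kernel between two points given as lists of floats."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2.0 * gamma ** 2))


def gram_of_embeddings(sample_sets, gamma=1.0):
    """G[i][j] = <mu_i, mu_j>_H, estimated with empirical kernel mean embeddings."""
    n = len(sample_sets)
    G = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            xs, ys = sample_sets[i], sample_sets[j]
            G[i][j] = sum(rbf(x, y, gamma) for x in xs for y in ys) / (len(xs) * len(ys))
    return G


def distributional_variance(sample_sets, gamma=1.0):
    """(1/N) sum_i ||mu_i - mu_hat||_H^2 = tr(G)/N - sum(G)/N^2."""
    G = gram_of_embeddings(sample_sets, gamma)
    n = len(G)
    trace = sum(G[i][i] for i in range(n))
    total = sum(sum(row) for row in G)
    return trace / n - total / n ** 2


rng = random.Random(1)
base = [[rng.gauss(0.0, 1.0)] for _ in range(50)]
same = [base, base, base]                       # P_1 = P_2 = P_3 (identical samples)
shifted = [base, [[x[0] + 3.0] for x in base]]  # second distribution is a translated copy
print(distributional_variance(same))     # zero up to floating-point rounding
print(distributional_variance(shifted))  # strictly positive
```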

Figures (5)

  • Figure 1: Histopathological image examples of breast cancer tissue from three different healthcare institutes, including NKI with 626 images, IHC with 645 images, and VGH with 1324 images. There are two different tissue types, including epithelium and stroma. Obvious domain gaps (e.g., the density of tissue and the staining color) can be observed.
  • Figure 2: A visualized computational process for probabilistic MMD (P-MMD) on two source domains. The same color for samples in different domains denotes the same label.
  • Figure 3: Loss curves over training iterations on the skin lesion and epithelium-stroma classification tasks. (a) Global alignment loss. (b) Local alignment loss.
  • Figure 4: The performance comparison between mean embedding method and kernel mean embedding method with different Monte Carlo samples $T$. For each sub-figure, we use only one alignment operation. (a) Local alignment. Mean Embedding: The mean embedding operation with Euclidean distance is utilized between probabilistic embedding pairs. Kernel Mean Embedding: The kernel mean embedding with MMD distance is utilized between probabilistic embedding pairs. (b) Global alignment. Mean Embedding: The mean embedding operation with MMD distance is utilized between domains (as distributions). Kernel Mean Embedding: The kernel mean embedding with P-MMD distance is utilized between domains (as distributions over distributions).
  • Figure 5: The performance of our proposed model on the NKI task of Epithelium Stroma classification with different Monte Carlo samples $T$.
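The local alignment described in the abstract and in Figure 4(a) can be sketched as a contrastive loss whose distance is the MMD between two probabilistic embeddings. This is a hedged illustration, not the paper's exact loss: it assumes the squared-distance / squared-hinge form of the deterministic CSA loss, with a hypothetical `margin`, RBF bandwidth `gamma`, and hypothetical helpers `mmd`, `mean_distance`, and `p_csa_loss`. The usage shows why a distribution-aware distance can matter: two embeddings with the same mean but very different spread are indistinguishable to the mean-embedding Euclidean distance, while their MMD is large.

```python
import math


def rbf(x, y, gamma=1.0):
    """Gaussian RBF kernel between two latent points given as lists of floats."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2.0 * gamma ** 2))


def mmd(xs, ys, gamma=1.0):
    """Biased MMD between two probabilistic embeddings, given their Monte Carlo samples."""
    def inner(a, b):
        return sum(rbf(u, v, gamma) for u in a for v in b) / (len(a) * len(b))
    return math.sqrt(max(inner(xs, xs) + inner(ys, ys) - 2.0 * inner(xs, ys), 0.0))


def mean_distance(xs, ys):
    """Euclidean distance between sample means (the mean-embedding baseline)."""
    dim = len(xs[0])
    mx = [sum(x[k] for x in xs) / len(xs) for k in range(dim)]
    my = [sum(y[k] for y in ys) / len(ys) for k in range(dim)]
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(mx, my)))


def p_csa_loss(pairs, margin=1.0, dist=mmd):
    """Contrastive loss over cross-domain (samples_a, samples_b, same_class) triples."""
    total = 0.0
    for xs, ys, same_class in pairs:
        d = dist(xs, ys)
        if same_class:
            total += 0.5 * d ** 2                     # pull positive pairs together
        else:
            total += 0.5 * max(0.0, margin - d) ** 2  # push negatives beyond the margin
    return total / len(pairs)


# Usage: two embeddings with the same mean (0.0) but very different spread.
narrow = [[-0.05], [0.05]]
wide = [[-2.0], [2.0]]
print(mean_distance(narrow, wide))  # 0.0: the means coincide
print(mmd(narrow, wide))            # about 1.1: the spreads differ
print(p_csa_loss([(narrow, narrow, True), (narrow, wide, False)], margin=2.0))
```

Passing `dist=mean_distance` instead of the default reproduces, in spirit, the mean-embedding baseline compared in Figure 4(a).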

Theorems & Definitions (2)

  • Theorem 1: muandet2012learning
  • Remark 1