Domain Generalization with Small Data
Kecheng Chen, Elena Gal, Hong Yan, Haoliang Li
TL;DR
This work tackles domain generalization (DG) under small-data conditions in medical imaging by introducing probabilistic embeddings obtained from Bayesian neural networks. It extends the maximum mean discrepancy to a distribution-over-distributions setting ($P$-MMD) and couples it with a probabilistic contrastive semantic alignment loss ($P$-CSA), capturing both global and local domain alignment. Experiments on epithelium-stroma classification, skin lesion classification, and spinal cord gray matter segmentation show improved cross-domain performance over state-of-the-art baselines, driven by uncertainty-aware representations and distribution-aware learning. The paper also discusses hyperparameter trade-offs and directions for future improvement, positioning the approach as a principled route to robust DG in data-scarce clinical settings.
Abstract
In this work, we tackle domain generalization in the context of \textit{insufficient samples}. Instead of extracting latent feature embeddings with deterministic models, we learn a domain-invariant representation within a probabilistic framework by mapping each data point to a probabilistic embedding. Specifically, we first extend the empirical maximum mean discrepancy (MMD) to a novel probabilistic MMD that measures the discrepancy between mixture distributions (i.e., source domains), each consisting of a series of latent distributions rather than latent points. Moreover, instead of imposing the contrastive semantic alignment (CSA) loss on pairs of latent points, a novel probabilistic CSA loss pulls positive pairs of probabilistic embeddings closer together while pushing negative pairs apart. Benefiting from the representations learned by probabilistic models, our method combines a measurement on the \textit{distribution over distributions} (i.e., global alignment) with distribution-based contrastive semantic alignment (i.e., local alignment). Extensive experiments on three challenging medical datasets show that our method outperforms state-of-the-art methods when training data are insufficient.
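The two alignment terms can be illustrated with a small NumPy sketch. It is not the authors' implementation and rests on several assumptions: each input is mapped to a diagonal Gaussian N(mu, diag(var)) standing in for the Bayesian network's probabilistic embedding; distributions are compared through their kernel mean embeddings under an RBF kernel (which have a closed-form inner product for Gaussians); a Gaussian kernel on those mean embeddings serves as the kernel between distributions for the probabilistic MMD; and the probabilistic CSA term uses a margin-based contrastive loss on the resulting distances. All function names and the parameters `ell`, `gamma`, and `margin` are hypothetical.

```python
import numpy as np


def mean_embedding_inner(mu_a, var_a, mu_b, var_b, ell=1.0):
    """<m_P, m_Q> in the RKHS of an RBF kernel with bandwidth `ell`, for every pair of
    diagonal Gaussians P_i = N(mu_a[i], diag(var_a[i])), Q_j = N(mu_b[j], diag(var_b[j])).
    Inputs are (n, d) and (m, d); the result is (n, m)."""
    s = ell**2 + var_a[:, None, :] + var_b[None, :, :]  # combined variance, (n, m, d)
    diff2 = (mu_a[:, None, :] - mu_b[None, :, :]) ** 2
    log_scale = 0.5 * np.sum(np.log(ell**2 / s), axis=-1)
    return np.exp(log_scale - 0.5 * np.sum(diff2 / s, axis=-1))


def embedding_sq_dist(mu_a, var_a, mu_b, var_b, ell=1.0):
    """Squared RKHS distance ||m_P - m_Q||^2 for every pair (P_i, Q_j)."""
    self_a = np.exp(0.5 * np.sum(np.log(ell**2 / (ell**2 + 2.0 * var_a)), axis=-1))
    self_b = np.exp(0.5 * np.sum(np.log(ell**2 / (ell**2 + 2.0 * var_b)), axis=-1))
    cross = mean_embedding_inner(mu_a, var_a, mu_b, var_b, ell)
    # Clip tiny negative values caused by floating-point cancellation.
    return np.maximum(self_a[:, None] + self_b[None, :] - 2.0 * cross, 0.0)


def level2_kernel(mu_a, var_a, mu_b, var_b, ell=1.0, gamma=1.0):
    """Gaussian kernel between distributions: K(P, Q) = exp(-||m_P - m_Q||^2 / (2 gamma^2))."""
    return np.exp(-embedding_sq_dist(mu_a, var_a, mu_b, var_b, ell) / (2.0 * gamma**2))


def p_mmd2(mu_s, var_s, mu_t, var_t, ell=1.0, gamma=1.0):
    """Biased estimate of the squared MMD between two domains, each represented as a
    set of probabilistic embeddings (a distribution over distributions)."""
    k_ss = level2_kernel(mu_s, var_s, mu_s, var_s, ell, gamma).mean()
    k_tt = level2_kernel(mu_t, var_t, mu_t, var_t, ell, gamma).mean()
    k_st = level2_kernel(mu_s, var_s, mu_t, var_t, ell, gamma).mean()
    return k_ss + k_tt - 2.0 * k_st


def p_csa(mu_a, var_a, y_a, mu_b, var_b, y_b, margin=1.0, ell=1.0):
    """Contrastive alignment of probabilistic embeddings from two domains: same-class
    pairs are pulled together; different-class pairs are pushed apart until their
    RKHS distance exceeds `margin`."""
    d2 = embedding_sq_dist(mu_a, var_a, mu_b, var_b, ell)
    same = y_a[:, None] == y_b[None, :]
    pos = 0.5 * d2[same]
    neg = 0.5 * np.maximum(margin - np.sqrt(d2[~same]), 0.0) ** 2
    return np.concatenate([pos, neg]).mean()


# Usage on random embeddings for two toy domains.
rng = np.random.default_rng(0)
mu_s, var_s = rng.normal(size=(8, 4)), np.full((8, 4), 0.1)
mu_t, var_t = rng.normal(size=(6, 4)) + 1.0, np.full((6, 4), 0.1)
y_s, y_t = rng.integers(0, 2, 8), rng.integers(0, 2, 6)
print(p_mmd2(mu_s, var_s, mu_t, var_t), p_csa(mu_s, var_s, y_s, mu_t, var_t, y_t))
```

Setting every variance to zero collapses each embedding to a point, and both terms reduce to their point-based counterparts in kernel space. Because the RBF kernel is non-negative and each mean embedding has norm at most 1, distances in this sketch never exceed the square root of 2, so a useful `margin` must be below that value.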
