Batch Normalization Embeddings for Deep Domain Generalization

Mattia Segu, Alessio Tonioni, Federico Tombari

TL;DR

This work introduces Batch Normalization Embeddings (BNE) to address domain generalization by learning domain-specific BN statistics for multiple source domains, forming a latent domain space. Unknown test domains are projected into this space via instance statistics, and their similarity to known domains determines a weighted ensemble of domain-specific classifiers, enabling robust predictions without target-domain supervision. Across PACS, Office-31, and Office-Caltech, BNE yields significant accuracy gains over strong baselines and many state-of-the-art methods, particularly on challenging domains. The approach leverages a principled Wasserstein-based distance on Gaussian BN statistics and demonstrates that maintaining domain-specific representations can outperform forcing invariant features, with potential extensions to domain adaptation scenarios.
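The Wasserstein-based distance mentioned above has a simple closed form if each set of BN statistics (per-channel mean and variance) is read as a Gaussian with diagonal covariance. The sketch below assumes exactly that diagonal model and a single BN layer. The function name `wasserstein2_diag` is hypothetical, and this summary does not state how the method aggregates distances across several BN layers.

```python
import math
from typing import Sequence


def wasserstein2_diag(mu_a: Sequence[float], var_a: Sequence[float],
                      mu_b: Sequence[float], var_b: Sequence[float]) -> float:
    """2-Wasserstein distance between N(mu_a, diag(var_a)) and N(mu_b, diag(var_b)).

    For diagonal covariances the general closed form reduces to
    sqrt(||mu_a - mu_b||^2 + ||sqrt(var_a) - sqrt(var_b)||^2).
    """
    if not (len(mu_a) == len(var_a) == len(mu_b) == len(var_b)):
        raise ValueError("all statistics must have the same number of channels")
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_a, mu_b))
    std_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_a, var_b))
    return math.sqrt(mean_term + std_term)


# Hypothetical 2-channel statistics: a test sample versus one source domain.
test_mu, test_var = [0.0, 0.0], [1.0, 1.0]
print(wasserstein2_diag(test_mu, test_var, [3.0, 4.0], [1.0, 1.0]))  # 5.0
```

The mean term measures how far the channel activations are shifted, and the standard-deviation term measures how differently they are spread. Two domains with identical means but different contrast therefore still end up at a nonzero distance.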

Abstract

Domain generalization aims at training machine learning models to perform robustly across different and unseen domains. Several recent methods use multiple datasets to train models to extract domain-invariant features, hoping to generalize to unseen domains. Instead, we first explicitly train domain-dependent representations, using ad-hoc batch normalization layers to collect each domain's statistics independently. Then, we propose to use these statistics to map domains into a shared latent space, where membership in a domain can be measured by means of a distance function. At test time, we project samples from an unknown domain into the same space and infer properties of their domain as a linear combination of the known ones. We apply the same mapping strategy at training and test time, learning both a latent representation and a powerful but lightweight ensemble model. We show a significant increase in classification accuracy over current state-of-the-art techniques on popular domain generalization benchmarks: PACS, Office-31 and Office-Caltech.
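To project a single test sample into the domain space, the method uses its instance statistics. These are the per-channel mean and variance of that sample's own activations at a BN layer, which can then be compared with each source domain's population statistics. The following is a minimal sketch. It assumes a C x (H*W) feature map stored as plain Python lists, and `instance_statistics` is a hypothetical helper, not the authors' code.

```python
from statistics import fmean, pvariance
from typing import List, Sequence, Tuple


def instance_statistics(feature_map: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-channel mean and biased variance of ONE sample's activations.

    feature_map: one flat list of spatial activations per channel (C x H*W).
    Returns (means, variances), each of length C; together they locate the
    sample in the same space as the per-domain population statistics.
    """
    means = [fmean(channel) for channel in feature_map]
    # Population (biased) variance, as used when normalizing a single instance.
    variances = [pvariance(channel, mu=m) for channel, m in zip(feature_map, means)]
    return means, variances


# Hypothetical 2-channel sample with 4 spatial positions per channel.
print(instance_statistics([[1.0, 2.0, 3.0, 4.0], [5.0, 5.0, 5.0, 5.0]]))
# ([2.5, 5.0], [1.25, 0.0])
```

Each sample is handled on its own, so this projection needs neither target-domain batches nor target labels. That matches the single-sample test-time setting described in the abstract.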

Paper Structure

This paper contains 26 sections, 9 equations, 2 figures, 14 tables.

Figures (2)

  • Figure 1: Visualization of our method on the PACS dataset when the domains Art Painting, Photo and Cartoon are available at training time. We propose to use batch normalization layers to implicitly learn a domain space onto which we map both known (training) and unknown (testing) domains. At test time, we project each target sample independently into the domain space and locate it with respect to the known domains using the corresponding distances $D_{a,t}$, $D_{p,t}$, and $D_{c,t}$. The location of the unseen sample reveals properties of the unknown domain. We leverage these hints to improve the classification of each test sample by means of a linear combination of domain-specific classifiers, weighted by the inverse of the distances.
  • Figure 2: Our method on PACS (li2017deeper) with Sketch as the unknown domain. A Multi-Source Domain Alignment Layer (a) collects domain-specific population statistics and computes instance statistics for test samples. After training, the population statistics map the source domains, and the instance statistics map the test samples, into a latent space where domain similarity can be measured by distances between embedding vectors. In (b), we visualize the learned domain space $\mathcal{L}^1$ by means of a t-SNE plot of instance normalization and population statistics for a model trained with our method. Each test sample from the unseen domain Sketch can be localized through its instance statistics (cyan dots) with respect to the known domains, which are embedded by the population statistics (green dots). Considering a test sample embedding, e.g., $r_t$, the estimated distances (orange arrows) are used to weigh the predictions of the domain-specific classifiers.
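The two captions describe two pieces. The first is a layer that keeps separate population statistics for each source domain. The second is a final prediction formed as a linear combination of domain-specific classifiers, weighted by inverse distances. The sketch below illustrates both under stated assumptions:

- Population statistics are tracked with an exponential moving average. The momentum of 0.1 is a common BN default, assumed here.
- The inverse distances are normalized to sum to one.
- `DomainStats` and `weighted_ensemble` are hypothetical names, not the authors' implementation.

```python
from typing import List, Sequence


class DomainStats:
    """Running (population) per-channel statistics for ONE source domain."""

    def __init__(self, num_channels: int, momentum: float = 0.1):
        self.momentum = momentum
        self.mean = [0.0] * num_channels
        self.var = [1.0] * num_channels

    def update(self, batch_mean: Sequence[float], batch_var: Sequence[float]) -> None:
        # Exponential moving average; only batches from this domain call update().
        m = self.momentum
        self.mean = [(1 - m) * old + m * new for old, new in zip(self.mean, batch_mean)]
        self.var = [(1 - m) * old + m * new for old, new in zip(self.var, batch_var)]


def weighted_ensemble(domain_probs: Sequence[Sequence[float]],
                      distances: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Combine per-domain class probabilities with normalized inverse-distance weights."""
    if len(domain_probs) != len(distances):
        raise ValueError("need one distance per domain-specific classifier")
    inv = [1.0 / (d + eps) for d in distances]  # closer domain -> larger weight
    total = sum(inv)
    weights = [w / total for w in inv]
    num_classes = len(domain_probs[0])
    return [sum(w * probs[k] for w, probs in zip(weights, domain_probs))
            for k in range(num_classes)]


# Hypothetical usage: one DomainStats per source domain, and two classifiers
# whose predictions are mixed according to the test sample's distances.
stats = {name: DomainStats(num_channels=2) for name in ("art", "photo", "cartoon")}
stats["art"].update([1.0, 2.0], [3.0, 1.0])
print(weighted_ensemble([[1.0, 0.0], [0.0, 1.0]], distances=[1.0, 3.0]))
```

Because the weights are normalized, the combined output remains a valid probability distribution whenever each domain-specific classifier outputs one.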