Generative Classifier for Domain Generalization

Shaocong Long, Qianyu Zhou, Xiangtai Li, Chenhao Ying, Yunhai Tong, Lizhuang Ma, Yuan Luo, Dacheng Tao

TL;DR

This work tackles domain generalization (DG) by challenging the default practice of strict domain invariance with a generative classifier that can model multi-modal, domain-specific feature distributions. The proposed GCDG framework replaces the discriminative linear classifier with a Gaussian Mixture Model-based Heterogeneity Learning Classifier (HLC), augmented by Spurious Correlation Blocking and Diverse Component Balancing to capture beneficial domain-specific information while mitigating spurious patterns. The authors provide theoretical results showing that enforcing invariance can raise the target risk bound, and that relaxing invariance via a generative approach can reduce this bound and promote flat minima. Empirically, GCDG achieves competitive or state-of-the-art performance across five DG benchmarks and a face anti-spoofing dataset, and can be integrated as a plug-in with existing DG methods. The work offers a new direction for DG by leveraging domain-specific information through a principled generative classifier, with practical implications for robust cross-domain understanding.

Abstract

Domain generalization (DG) aims to improve the generalizability of computer vision models under distribution shifts. Mainstream DG methods focus on learning domain invariance; however, such methods overlook the potential inherent in domain-specific information. While the prevailing discriminative linear classifier has been tailored to domain-invariant features, it struggles when confronted with diverse domain-specific information, e.g., intra-class shifts, that exhibits multi-modality. To address these issues, we explore the theoretical implications of relying on domain invariance, revealing the crucial role of domain-specific information in mitigating the target risk for DG. Drawing on these insights, we propose Generative Classifier-driven Domain Generalization (GCDG), introducing a generative paradigm for the DG classifier based on Gaussian Mixture Models (GMMs) for each class across domains. GCDG consists of three key modules: Heterogeneity Learning Classifier (HLC), Spurious Correlation Blocking (SCB), and Diverse Component Balancing (DCB). Concretely, HLC attempts to model the feature distributions and thereby capture valuable domain-specific information via GMMs. SCB identifies the neural units containing spurious correlations and perturbs them, mitigating the risk of HLC learning spurious patterns. Meanwhile, DCB ensures a balanced contribution of components in HLC, preventing the underestimation or neglect of critical components. In this way, GCDG excels in capturing the nuances of domain-specific information characterized by diverse distributions. GCDG demonstrates the potential to reduce the target risk and encourage flat minima, improving generalizability. Extensive experiments show that GCDG achieves comparable performance on five DG benchmarks and one face anti-spoofing dataset, and that it integrates seamlessly into existing DG methods with consistent improvements.
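
To make the idea of a GMM-based generative classifier concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes diagonal covariances, mixture parameters that are already given, and a uniform class prior; the function names (`gmm_class_scores`, `predict`) and the array layout are hypothetical. Each class is scored by the log-likelihood of a feature vector under that class's mixture, and the predicted class is the one with the highest score.

```python
import numpy as np


def log_gaussian_diag(z, mean, var):
    """Log density of N(mean, diag(var)) at each row of z; z has shape (N, D)."""
    d = z.shape[1]
    diff = z - mean
    return -0.5 * (d * np.log(2.0 * np.pi) + np.sum(np.log(var))
                   + np.sum(diff ** 2 / var, axis=1))


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def gmm_class_scores(z, means, variances, weights):
    """Return log p(z | y = c) for every sample and class, shape (N, C).

    Assumed (hypothetical) layout: means and variances are (C, K, D),
    weights is (C, K) with each row summing to 1.
    """
    n_classes, n_comp, _ = means.shape
    scores = np.empty((z.shape[0], n_classes))
    for c in range(n_classes):
        # Per-component log(weight) + log density, shape (N, K).
        comp = np.stack(
            [np.log(weights[c, k])
             + log_gaussian_diag(z, means[c, k], variances[c, k])
             for k in range(n_comp)],
            axis=1,
        )
        scores[:, c] = logsumexp(comp, axis=1)
    return scores


def predict(z, means, variances, weights):
    """Pick the class whose mixture gives the highest likelihood (uniform prior)."""
    return np.argmax(gmm_class_scores(z, means, variances, weights), axis=1)


# Toy setup: class 0 has two modes on either side of the single mode of class 1.
means = np.array([[[-4.0, 0.0], [4.0, 0.0]],
                  [[0.0, 0.0], [0.0, 0.0]]])
variances = np.ones((2, 2, 2))
weights = np.full((2, 2), 0.5)
points = np.array([[-4.0, 0.0], [0.0, 0.0], [4.0, 0.0]])
print(predict(points, means, variances, weights))  # -> [0 1 0]
```

In this toy setup the three test points are collinear with the class-1 point in the middle, so no single linear decision boundary can label them correctly, whereas the per-class mixture handles the two modes of class 0 directly.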

Paper Structure

This paper contains 22 sections, 2 theorems, 15 equations, 4 figures, 15 tables, 1 algorithm.

Key Result

theorem 1

For a feature extractor $f$, if the features $Z_1 = f(X_1), \cdots, Z_M = f(X_M)$ across $M$ source domains are domain-invariant, i.e., $p(Z_1, Y) = \cdots = p(Z_M, Y)$, then $\inf_{g} \sum_{i = 1}^{M}\mathbb{E}_{p_i}[\ell_{CE}(g(Z_i), Y)] - \inf_{h} \sum_{i = 1}^{M} \mathbb{E}_{p_i}[\ell_{CE}(h(X_i), Y)]$ ...

Figures (4)

  • Figure 1: Comparison of modeling a class between the discriminative linear classifier and the proposed generative classifier in DG. (a) The prevailing linear classifier in DG operates under the assumption of unimodal distribution, encountering substantial challenges when confronted with domain-specific data that exhibits multi-modality. (b) In this paper, we introduce a novel generative classifier to capture the underlying multi-modal distribution present in domain-specific data.
  • Figure 2: The framework of our proposed GCDG. The key innovation is the Heterogeneity Learning Classifier, which is a generative classifier consisting of a mixture of Gaussians for each class and adept at effectively harnessing valuable domain-specific information exhibiting multi-modality. Besides, we introduce Spurious Correlation Blocking to shuffle the neural units containing spurious correlations, mitigating their adverse effect on capturing domain-specific information. Furthermore, Diverse Component Balancing is designed to balance the contributions of diverse components, avoiding underestimating essential ones.
  • Figure 3: (a) Spurious correlations across diverse scenarios may appear as domain-specific information and be mistakenly captured by our proposed HLC, damaging the generalizability. (b) We introduce Spurious Correlation Blocking (SCB) to perturb these spurious correlations, alleviating their detrimental effect on HLC.
  • Figure 4: Visualization of the loss landscapes for ERM, the flatness-aware method SWAD (Cha et al., 2021), and the proposed GCDG on PACS. Note that the loss landscape is visualized on the source domains. Notably, our proposed GCDG exhibits superior efficacy in fostering flat minima compared to ERM and the flatness-aware method SWAD.
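
The captions for Figures 2 and 3 describe shuffling the neural units that carry spurious correlations and balancing the contributions of mixture components. The sketch below illustrates both ideas in NumPy under loud assumptions: how suspect units are identified is not specified here, so `suspect_mask` is a hypothetical input, and `balance_penalty` is one plausible regularizer (KL divergence between average component usage and the uniform distribution), not necessarily the loss used in the paper.

```python
import numpy as np


def shuffle_units(features, suspect_mask, rng):
    """Permute the flagged feature units across the batch.

    features: (N, D) batch of features.
    suspect_mask: (D,) boolean array marking units assumed to be spurious
    (hypothetical input; the selection rule is not reproduced here).
    The unflagged units of every sample are left unchanged.
    """
    shuffled = features.copy()
    perm = rng.permutation(features.shape[0])
    # Each sample receives the flagged units of another (random) sample.
    shuffled[:, suspect_mask] = features[perm][:, suspect_mask]
    return shuffled


def balance_penalty(responsibilities, eps=1e-12):
    """KL(average component usage || uniform); an assumed balancing term.

    responsibilities: (N, K) soft assignments of samples to components,
    each row summing to 1. The penalty is about 0 when all components are
    used equally and grows as usage collapses onto a few components.
    """
    usage = responsibilities.mean(axis=0)
    n_comp = usage.shape[0]
    return float(np.sum(usage * (np.log(usage + eps) + np.log(n_comp))))


rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 3))
mask = np.array([False, True, False])
print(shuffle_units(feats, mask, rng))

balanced = np.array([[0.5, 0.5], [0.5, 0.5]])
collapsed = np.array([[1.0, 0.0], [1.0, 0.0]])
print(balance_penalty(balanced))   # close to 0
print(balance_penalty(collapsed))  # approximately log(2)
```

Shuffling across the batch keeps the marginal statistics of the flagged units while breaking their link to each sample's label, which is one simple way to perturb them; other perturbations would fit the same interface.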

Theorems & Definitions (4)

  • theorem 1
  • theorem 2
  • proof
  • proof