Gradient Matching for Domain Generalization

Yuge Shi, Jeffrey Seely, Philip H. S. Torr, N. Siddharth, Awni Hannun, Nicolas Usunier, Gabriel Synnaeve

TL;DR

This work tackles domain generalization by proposing Inter-Domain Gradient Matching (IDGM), which encourages invariant input-output mappings by maximizing the gradient inner product across source domains. To avoid expensive second-order optimization, the authors introduce Fish, a first-order, meta-learning-style algorithm that approximates IDGM and scales to multi-domain settings. Empirically, Fish delivers competitive or state-of-the-art performance on the Wilds and DomainBed benchmarks across vision and language tasks, holds up across diverse architectures, and clearly improves gradient alignment over standard ERM. The approach offers a practical, scalable mechanism to reduce reliance on domain-specific spurious correlations and promote robust generalization in real-world deployment.
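The gradient inner product mentioned above can be illustrated with a small NumPy sketch. It assumes a linear model with a mean-squared-error loss, and the function names `domain_gradient` and `gradient_inner_product` are hypothetical. Averaging over distinct domain pairs is one natural way to turn pairwise inner products into a single score, not necessarily the exact normalization used in the paper.

```python
import numpy as np

def domain_gradient(theta, X, y):
    """Gradient of 0.5 * mean((X @ theta - y)**2) with respect to theta."""
    residual = X @ theta - y
    return X.T @ residual / len(y)

def gradient_inner_product(grads):
    """Average inner product over all distinct pairs of per-domain gradients."""
    S = len(grads)
    total = 0.0
    for i in range(S):
        for j in range(i + 1, S):
            total += float(grads[i] @ grads[j])
    # There are S * (S - 1) / 2 distinct pairs.
    return 2.0 * total / (S * (S - 1))

# Two domains whose gradients point the same way score positively...
aligned = [np.array([1.0, 0.0]), np.array([1.0, 0.0])]
# ...while opposing gradients score negatively.
conflicting = [np.array([1.0, 0.0]), np.array([-1.0, 0.0])]

print(gradient_inner_product(aligned))      # 1.0
print(gradient_inner_product(conflicting))  # -1.0
```

A positive score means a step that reduces the loss on one domain also tends to reduce it on the others. A negative score signals that the domains pull the parameters in conflicting directions.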

Abstract

Machine learning systems typically assume that the distributions of training and test sets match closely. However, a critical requirement of such systems in the real world is their ability to generalize to unseen domains. Here, we propose an inter-domain gradient matching objective that targets domain generalization by maximizing the inner product between gradients from different domains. Since direct optimization of the gradient inner product can be computationally prohibitive (it requires computing second-order derivatives), we derive a simpler first-order algorithm named Fish that approximates its optimization. We demonstrate the efficacy of Fish on 6 datasets from the Wilds benchmark, which captures distribution shift across a diverse range of modalities. Our method produces competitive results on these datasets and surpasses all baselines on 4 of them. We perform experiments on both the Wilds benchmark, which captures distribution shift in the real world, and datasets in the DomainBed benchmark, which focuses more on synthetic-to-real transfer. Our method produces competitive results on both benchmarks, demonstrating its effectiveness across a wide range of domain generalization tasks.
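To make the idea of a first-order, meta-learning-style update concrete, here is a minimal sketch. The abstract does not spell out the update rule, so the structure below is an assumption modeled on Reptile-style meta-learning: take sequential gradient steps on a cloned copy of the parameters, one domain at a time, then interpolate the original parameters toward the clone. The names `mse_grad`, `fish_style_step`, `inner_lr` and `meta_lr` are hypothetical, and a linear model with squared loss stands in for the real network.

```python
import numpy as np

def mse_grad(theta, X, y):
    """Gradient of 0.5 * mean((X @ theta - y)**2) with respect to theta."""
    return X.T @ (X @ theta - y) / len(y)

def fish_style_step(theta, domains, inner_lr=0.1, meta_lr=0.5, rng=None):
    """One outer update (assumed Reptile-style form, not taken from the abstract).

    domains: list of (X, y) pairs, one per training domain.
    """
    rng = np.random.default_rng() if rng is None else rng
    theta_tilde = theta.copy()
    # Inner loop: plain first-order steps, visiting domains in random order.
    for s in rng.permutation(len(domains)):
        X, y = domains[s]
        theta_tilde = theta_tilde - inner_lr * mse_grad(theta_tilde, X, y)
    # Outer step: move the original weights toward the adapted clone.
    return theta + meta_lr * (theta_tilde - theta)

# Usage on a single toy domain (values chosen only for illustration).
domains = [(np.eye(2), np.array([1.0, 2.0]))]
theta = fish_style_step(np.zeros(2), domains)
print(theta)
```

Only first-order gradients are computed here. Any coupling between domains comes from the order-dependent inner loop, which is why no second derivatives are needed.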

Paper Structure

This paper contains 35 sections, 1 theorem, 20 equations, 9 figures, 16 tables, 6 algorithms.

Key Result

Theorem 3.1

Given a twice-differentiable model with parameters $\theta$ and objective $l$, let us define the following: where $\bar{G}=\frac{1}{S}\sum_{s=1}^{S} G_s$ is the full gradient of ERM. Then we have

Figures (9)

  • Figure 1: Isometric projection of training with ERM (blue) vs. our IDGM objective (dark blue), using data from Figure 2.
  • Figure 2: All domains contain 3 types of inputs, $x_1$, $x_2$ and $x_3$, each depicted in one column. 1$^{st}$ col.: $x_1=[0,0,0,0]$, $y=0$; makes up $50\%$ of each dataset. 2$^{nd}$ col.: $x_2$ changes for each domain, $y=1$ always; makes up $40\%$ of each dataset. 3$^{rd}$ col.: $x_3=[1,0,0,0]$, with $y=1$ for $30\%$ and $y=0$ for $70\%$ of these inputs; makes up $10\%$ of each dataset.
  • Figure 3: CdSprites-N train and test splits. Each 3x3 grid in train (e.g. yellow block) represents one domain.
  • Figure 4: Performance on CdSprites-N, with $N \in [5, 50]$
  • Figure 5: Gradient inner product values during training for CdSprites-N (N=15).
  • ...and 4 more figures
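
The toy domains described in the Figure 2 caption can be sketched as follows. This is only an illustration of the stated proportions: the function name `make_toy_domain`, the dataset size `n`, and the example value passed for `x2` are hypothetical, since the caption says only that $x_2$ changes per domain.

```python
import numpy as np

def make_toy_domain(x2, n=1000):
    """Build one domain following the proportions in the Figure 2 caption.

    x2: the domain-specific 4-dimensional input (an assumed value;
        the caption does not list the actual vectors).
    """
    n1 = n // 2          # 50%: x1 = [0, 0, 0, 0], y = 0
    n2 = (2 * n) // 5    # 40%: x2 (domain-specific), y = 1
    n3 = n - n1 - n2     # 10%: x3 = [1, 0, 0, 0], mixed labels
    n3_pos = (3 * n3) // 10  # 30% of the x3 examples have y = 1

    X = np.vstack([
        np.zeros((n1, 4)),
        np.tile(np.asarray(x2, dtype=float), (n2, 1)),
        np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (n3, 1)),
    ])
    y3 = np.zeros(n3, dtype=int)
    y3[:n3_pos] = 1
    y = np.concatenate([np.zeros(n1, dtype=int), np.ones(n2, dtype=int), y3])
    return X, y

# Hypothetical domain-specific input; any vector distinct from x1 and x3 would do.
X, y = make_toy_domain(x2=[0.0, 1.0, 0.0, 0.0])
print(X.shape, y.mean())
```

In this construction $x_2$ predicts $y=1$ perfectly within a domain but differs between domains, whereas $x_3$ is shared by all domains and only weakly predictive. The pattern tied to $x_2$ is therefore domain-specific.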

Theorems & Definitions (1)

  • Theorem 3.1