From Invariant Representations to Invariant Data: Provable Robustness to Spurious Correlations via Noisy Counterfactual Matching

Ruqi Bai, Yao Ji, Zeyu Zhou, David I. Inouye

TL;DR

This work introduces Noisy Counterfactual Matching (NCM), a simple constraint-based method that improves robustness by leveraging even a small number of counterfactual pairs -- improving upon prior works that do not explicitly consider noise.

Abstract

Models that learn spurious correlations from training data often fail when deployed in new environments. While many methods aim to learn invariant representations to address this, they often underperform standard empirical risk minimization (ERM). We propose a data-centric alternative that shifts the focus from learning invariant representations to leveraging invariant data pairs -- pairs of samples that should have the same prediction. We prove that certain counterfactuals naturally satisfy this invariance property. Based on this, we introduce Noisy Counterfactual Matching (NCM), a simple constraint-based method that improves robustness by leveraging even a small number of noisy counterfactual pairs -- improving upon prior works that do not explicitly consider noise. For linear causal models, we prove that NCM's test-domain error is bounded by its in-domain error plus a term dependent on the counterfactuals' quality and diversity. Experiments on synthetic data validate our theory, and we demonstrate NCM's effectiveness on real-world datasets.

Paper Structure

This paper contains 44 sections, 8 theorems, 40 equations, 13 figures, 9 tables, 1 algorithm.

Key Result

Proposition 1

Given a spurious correlation latent SCM class $\mathcal{M}_\mathcal{E}$ and a strictly convex loss function $\ell$, any observed counterfactual pair $({\bm{x}}_e, {\bm{x}}_{e'})$ between $\mathcal{M}_e \in \mathcal{M}_\mathcal{E}$ and $\mathcal{M}_{e'} \in \mathcal{M}_\mathcal{E}$ will be an invaria

Figures (13)

  • Figure 1: While ERM $\hat{\theta}$ on the training domains (circles and triangles) is not robust to the change in spurious feature in the unseen test domain (pluses), a robust linear classifier $\theta^*$ can be estimated by making the classifier orthogonal to the difference between a single invariant pair (green line). The color represents label ${\textnormal{y}}$.
  • Figure 2: Results on the synthetic dataset with NCM. We report both in-domain test accuracy (in-test accuracy) and test-domain accuracy (test accuracy). We choose $m=100$ and $\lvert\mathcal{I}(\mathcal{F}_\mathcal{E})\rvert=20$ (marked by the vertical dashed line at $20$). The horizontal lines represent the ERM accuracy and the oracle accuracy (ERM trained on the test domain). $\varepsilon=0$ corresponds to oracle CF pairs. The solid curves show the mean over 10 runs, with shaded regions indicating standard deviations.
  • Figure 3: Illustration of the latent causal model. The ancestors of the target ${\textnormal{y}}$ are ${\mathbf{z}}_1, {\mathbf{z}}_2$, which are assumed to be invariant across domains (see \ref{ass:spurious-corr}). In contrast, ${\mathbf{z}}_3, {\mathbf{z}}_4$ are spurious features: ${\mathbf{z}}_3$ is confounded with ${\textnormal{y}}$, and ${\mathbf{z}}_4$ is a descendant of ${\textnormal{y}}$. Because neither is an ancestor of ${\textnormal{y}}$, both are spurious.
  • Figure 4: An illustration of oracle counterfactual pairs represented by our model, where $f_1$ and $f_2$ are the solution functions of the SCMs for domain $1$ and domain $2$, and $g_{{\mathbf{x}}}$ is the observation function from ${\mathbf{z}}$ to ${\mathbf{x}}$. In this figure, we do not plot the prediction target ${\textnormal{y}}$ or, correspondingly, $g_{{\textnormal{y}}}$.
  • Figure 5: Illustration of robustness prediction corresponding to non-latent and latent causal variables.
  • ...and 8 more figures
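
The idea in the Figure 1 caption, making a linear classifier orthogonal to the difference of an invariant pair, can be sketched in a few lines of NumPy. This is only an illustrative sketch under stated assumptions, not the paper's NCM algorithm: it covers the noiseless case with a least-squares loss, and the function names (`fit_orthogonal_to_pairs`, `make_domain`), the two-feature data generator, and the single hand-built pair are all hypothetical choices made for the example. For a linear predictor $\theta^\top x$, requiring equal predictions on a pair $(x_e, x_{e'})$ is the same as requiring $\theta^\top (x_e - x_{e'}) = 0$, which the code enforces by projecting out the span of the pair differences before fitting.

```python
import numpy as np


def fit_orthogonal_to_pairs(X, y, pairs_a, pairs_b, tol=1e-10):
    """Least-squares linear fit with theta @ (a - b) = 0 for every pair (a, b)."""
    d = X.shape[1]
    # Rows of D are the pair differences the classifier must ignore.
    D = np.atleast_2d(pairs_a - pairs_b)
    # Orthonormal basis of span(D) from the right singular vectors.
    _, s, Vt = np.linalg.svd(D, full_matrices=False)
    rank = int(np.sum(s > tol * max(s.max(), 1.0)))
    V = Vt[:rank].T
    # Projector onto the orthogonal complement of span(D).
    P = np.eye(d) - V @ V.T
    # Fit on projected features; projecting theta keeps the constraint exact.
    theta, *_ = np.linalg.lstsq(X @ P, y, rcond=None)
    return P @ theta


def make_domain(n, spurious_sign, rng):
    """Hypothetical two-feature domain: one invariant and one spurious feature."""
    y = rng.choice([-1.0, 1.0], size=n)
    invariant = y + rng.normal(scale=1.0, size=n)
    spurious = spurious_sign * 3.0 * y + rng.normal(scale=0.5, size=n)
    return np.column_stack([invariant, spurious]), y


def accuracy(theta, X, y):
    return float(np.mean(np.sign(X @ theta) == y))


rng = np.random.default_rng(0)
X_train, y_train = make_domain(2000, +1.0, rng)  # spurious feature agrees with y
X_test, y_test = make_domain(2000, -1.0, rng)    # spurious correlation is flipped

# A single noiseless pair: same invariant feature, different spurious feature.
pair_a = np.array([[0.3, 3.0]])
pair_b = np.array([[0.3, -3.0]])

# Unconstrained least squares (ERM with squared loss) for comparison.
theta_erm, *_ = np.linalg.lstsq(X_train, y_train, rcond=None)
theta_robust = fit_orthogonal_to_pairs(X_train, y_train, pair_a, pair_b)

acc_erm = accuracy(theta_erm, X_test, y_test)
acc_robust = accuracy(theta_robust, X_test, y_test)
print("ERM test-domain accuracy:   ", acc_erm)
print("Robust test-domain accuracy:", acc_robust)
```

In this toy setup the unconstrained fit leans on the spurious feature and degrades when its correlation flips, while the constrained fit places no weight on it. Handling noisy pairs, which is the focus of NCM, would require more than this exact projection and is not attempted here.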

Theorems & Definitions (21)

  • Definition 1: Structural Causal Model (Pearl, 2009)
  • Definition 2: Intervention Set
  • Definition 3: Counterfactual Pair
  • Definition 4: Class of Latent Domain SCMs
  • Definition 5: Optimally Robust Classifier
  • Definition 6: Invariant Pair
  • Proposition 1: Spurious Counterfactuals are Invariant Pairs
  • Theorem 1: Test-Domain Error Bound for NCM with Linear Models
  • Corollary 2: Test-Domain Error Bound for ERM with Linear Models
  • Corollary 3: Test-Domain Bound in Terms of Counterfactual Noise
  • ...and 11 more