Weighted Risk Invariance: Domain Generalization under Invariant Feature Shift

Gina Wong, Joshua Gleason, Rama Chellappa, Yoav Wald, Anqi Liu

TL;DR

The paper tackles out-of-distribution (OOD) generalization by targeting invariant covariate shift, where the conditional $p(Y|X_{\text{inv}})$ remains stable but the marginal $p(X_{\text{inv}})$ shifts. It introduces Weighted Risk Invariance (WRI), which enforces invariance of reweighted losses across environments using density-based weighting of the invariant features, and proves that WRI learns invariant models (i.e., discards spurious correlations) in a linear-Gaussian setting. An alternating minimization algorithm is proposed to jointly learn the predictor and the invariant feature densities, and the method is empirically shown to outperform IRM and VREx under invariant covariate shift on ColoredMNIST and DomainBed benchmarks, with the learned densities also providing useful OOD detection signals. Overall, WRI provides a principled, density-aware approach to robust domain generalization, with practical benefits for trustworthiness via density-based uncertainty indication.
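The density-based reweighting described above can be made concrete with a short NumPy sketch. It assumes a cross-density scheme, in which losses from environment i are weighted by environment j's density of the invariant feature and vice versa. It also assumes that the invariant feature is one-dimensional with known Gaussian densities, and that the penalty is the squared gap between the two weighted risks. The helper names (`gaussian_pdf`, `wri_penalty`) and all parameters are hypothetical; this illustrates the idea and is not the authors' implementation.

```python
import numpy as np


def gaussian_pdf(x, mean, std):
    """Density of the 1-D Gaussian N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def wri_penalty(loss_i, xinv_i, dens_i, loss_j, xinv_j, dens_j):
    """Squared gap between two cross-weighted empirical risks (hypothetical form).

    Examples from environment i are weighted by environment j's density of the
    invariant feature, and examples from environment j by environment i's density.
    """
    risk_i = np.mean(dens_j(xinv_i) * loss_i)
    risk_j = np.mean(dens_i(xinv_j) * loss_j)
    return float((risk_i - risk_j) ** 2)


# Two environments whose invariant feature is shifted (invariant covariate shift).
rng = np.random.default_rng(0)
dens_1 = lambda x: gaussian_pdf(x, 0.0, 1.0)
dens_2 = lambda x: gaussian_pdf(x, 1.0, 1.0)
x1 = rng.normal(0.0, 1.0, 200_000)
x2 = rng.normal(1.0, 1.0, 200_000)

# Loss depending only on x_inv, as for a predictor that ignores spurious features.
loss_fn = lambda x: 0.25 * x ** 2
pen_invariant = wri_penalty(loss_fn(x1), x1, dens_1, loss_fn(x2), x2, dens_2)

# Loss whose conditional mean given x_inv differs between environments,
# standing in for a predictor that leans on an unstable spurious feature.
pen_spurious = wri_penalty(loss_fn(x1), x1, dens_1, loss_fn(x2) + 0.5, x2, dens_2)

print(f"invariant: {pen_invariant:.2e}, spurious: {pen_spurious:.2e}")
```

With cross-density weights, both weighted risks estimate the same integral $\int p_1(x)\,p_2(x)\,\mathbb{E}[\ell \mid x]\,dx$ when the conditional loss is shared, so the first penalty is close to zero. Adding 0.5 to the second environment's loss moves its risk by $0.5\int p_1 p_2 \approx 0.11$, which gives a penalty near 0.012.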

Abstract

Learning models whose predictions are invariant under multiple environments is a promising approach for out-of-distribution generalization. Such models are trained to extract features $X_{\text{inv}}$ where the conditional distribution $Y \mid X_{\text{inv}}$ of the label given the extracted features does not change across environments. Invariant models are also supposed to generalize to shifts in the marginal distribution $p(X_{\text{inv}})$ of the extracted features $X_{\text{inv}}$, a type of shift we call an invariant covariate shift. However, we show that proposed methods for learning invariant models underperform under invariant covariate shift, either failing to learn invariant models (even for data generated from simple and well-studied linear-Gaussian models) or having poor finite-sample performance. To alleviate these problems, we propose weighted risk invariance (WRI). Our framework is based on imposing invariance of the loss across environments subject to appropriate reweightings of the training examples. We show that WRI provably learns invariant models, i.e. discards spurious correlations, in linear-Gaussian settings. We propose a practical algorithm to implement WRI by learning the density $p(X_{\text{inv}})$ and the model parameters simultaneously, and we demonstrate empirically that WRI outperforms previous invariant learning methods under invariant covariate shift.
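The practical algorithm learns the density and the model parameters together. The toy sketch below shows one way such an alternation can be organized, under several assumptions that are not taken from the paper. The model is linear in $[x_{\text{inv}}, x_{\text{spu}}]$, and its one-dimensional output doubles as the learned feature. Each environment's density is a Gaussian fit by maximum likelihood. The weights are cross-densities held fixed during each inner loop, and gradients come from central finite differences. All names (`make_env`, `objective`, `fit_alternating`) and hyperparameters are hypothetical. The sketch illustrates the structure of the alternation and is not tuned to reproduce the paper's results.

```python
import numpy as np


def gauss_pdf(z, mean, std):
    """Density of the 1-D Gaussian N(mean, std^2) evaluated at z."""
    return np.exp(-0.5 * ((z - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def make_env(rng, n, mu_inv, spu_noise):
    """Hypothetical linear-Gaussian environment with x_inv -> y -> x_spu."""
    x_inv = rng.normal(mu_inv, 1.0, n)
    y = x_inv + rng.normal(0.0, 0.5, n)
    x_spu = y + rng.normal(0.0, spu_noise, n)
    return np.stack([x_inv, x_spu], axis=1), y


def objective(theta, envs, weights, lam):
    """ERM loss plus lam times the squared gap between the two weighted risks."""
    losses = [(X @ theta - y) ** 2 for X, y in envs]
    erm = np.mean(np.concatenate(losses))
    risks = [np.mean(w * loss) for w, loss in zip(weights, losses)]
    return erm + lam * (risks[0] - risks[1]) ** 2


def fit_alternating(envs, lam=10.0, lr=0.02, outer=30, inner=20, eps=1e-5):
    theta = np.array([0.5, 0.5])  # coefficients on [x_inv, x_spu]
    for _ in range(outer):
        # Density step: maximum-likelihood Gaussian fit to the current
        # 1-D feature z = X @ theta in each environment.
        feats = [X @ theta for X, _ in envs]
        params = [(z.mean(), z.std() + 1e-6) for z in feats]
        # Cross-density weights, held fixed during the model step below.
        weights = [gauss_pdf(feats[0], *params[1]),
                   gauss_pdf(feats[1], *params[0])]
        # Model step: gradient descent on theta via central finite differences.
        for _ in range(inner):
            grad = np.zeros_like(theta)
            for k in range(theta.size):
                step = np.zeros_like(theta)
                step[k] = eps
                grad[k] = (objective(theta + step, envs, weights, lam)
                           - objective(theta - step, envs, weights, lam)) / (2 * eps)
            theta = theta - lr * grad
    return theta


rng = np.random.default_rng(0)
envs = [make_env(rng, 5_000, mu_inv=0.0, spu_noise=0.1),
        make_env(rng, 5_000, mu_inv=1.0, spu_noise=1.0)]
theta = fit_alternating(envs)
print("coefficients on [x_inv, x_spu]:", theta)
```

Freezing the weights during the model step is what makes this an alternation: the inner loop treats the densities as constants, and the densities are only refreshed from the current features at the start of each outer iteration.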

Paper Structure (42 sections, 4 theorems, 40 equations, 10 figures, 8 tables, 1 algorithm)

Key Result

Proposition 0

Let the assumption defining our setting hold over a set of environments $\mathcal{E}\subseteq\mathcal{P}$. If a predictor $f$ is spurious-free over $\mathcal{E}$, then for every pair of environments $e_i, e_j \in \mathcal{E}$, their weighted risks are equal if their respective weighting functions …
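Because the statement above is cut off, the exact condition on the weighting functions is not shown. The Monte Carlo check below assumes the cross-density choice, in which environment $e_i$'s losses are weighted by $p_{e_j}(x_{\text{inv}})$ and vice versa. It also assumes a hypothetical two-environment linear-Gaussian data generator and squared loss. It compares a spurious-free predictor $f(x) = x_{\text{inv}}$ with one that also uses $x_{\text{spu}}$, and illustrates the proposition's claim under these assumptions rather than proving it.

```python
import numpy as np


def gauss_pdf(x, mean, std):
    """Density of the 1-D Gaussian N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


rng = np.random.default_rng(0)
N = 200_000
MU_INV = (0.0, 1.0)     # mean of x_inv per environment: invariant covariate shift
SPU_NOISE = (0.1, 1.0)  # std of the spurious feature's noise per environment


def sample(mu_inv, spu_noise):
    x_inv = rng.normal(mu_inv, 1.0, N)
    y = x_inv + rng.normal(0.0, 0.5, N)        # p(y | x_inv) shared by all envs
    x_spu = y + rng.normal(0.0, spu_noise, N)  # env-dependent link to y
    return x_inv, x_spu, y


envs = [sample(m, s) for m, s in zip(MU_INV, SPU_NOISE)]


def cross_weighted_risks(predict):
    """Squared-loss risk of env k weighted by the other env's density of x_inv."""
    risks = []
    for k, (x_inv, x_spu, y) in enumerate(envs):
        weight = gauss_pdf(x_inv, MU_INV[1 - k], 1.0)
        risks.append(float(np.mean(weight * (predict(x_inv, x_spu) - y) ** 2)))
    return risks


r_inv = cross_weighted_risks(lambda xi, xs: xi)                   # spurious-free
r_spu = cross_weighted_risks(lambda xi, xs: 0.5 * xi + 0.5 * xs)  # uses x_spu
print("spurious-free:", r_inv)
print("spurious:     ", r_spu)
```

Analytically, both spurious-free risks equal $0.25\int p_0(x)\,p_1(x)\,dx \approx 0.25 \times 0.22 \approx 0.055$, because the residual is the label noise alone and is independent of $x_{\text{inv}}$. The spurious predictor's residual is $0.5(\eta - \varepsilon)$, where $\varepsilon$ is the label noise and $\eta$ the spurious-feature noise; its variance depends on the environment, so the two weighted risks come out near 0.014 and 0.069.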

Figures (10)

  • Figure 1: Top row: a Gaussian setup exhibiting both heteroskedasticity and invariant covariate shift. Note that for invariant features $x_{inv}$, $p(y|x_{inv})$ is the same across environments, while $p(y|x_{spu})$ changes across environments, following our causal model. Bottom row: we visualize the performance of different algorithms under this setup, and find that the WRI objective recovers a more invariant predictor than the other algorithms. Specifically, the WRI predictor bases its predictions more on the invariant features $x_{inv}$ and less on the spurious features $x_{spu}$, so in this case it learns a vertical decision boundary. The full parameters for this simulation can be found in the experiments appendix.
  • Figure 2: Causal graph depicting our data-generating process. The environment $E$ is dashed to emphasize it takes on unobserved values at test time.
  • Figure 3: (a) We demonstrate the finite sample behavior of ERM, IRM, VREx, and WRI methods in the case where the distribution shift is large. We sample data from the distributions shown in (b). Even as the number of samples increases, ERM, IRM and VREx methods continue to select spurious classifiers while the WRI method quickly converges to an invariant classifier.
  • Figure 4: The WRI penalty is more sensitive to spurious classifiers than the IRM penalty under invariant covariate shift. (a) We start with two data distributions ($p_1$ and $p_2$) and increase the distance between their means from 0 to 2 in the invariant direction, essentially creating invariant covariate shift ranging in degree from the diagram in (b) to the diagram in (c). We then search through the set of challenging, slightly spurious classifiers with near-optimal ERM loss, so that we must rely on the invariant penalty to reject these classifiers. We plot how the minimum IRM and WRI penalties in these regions change as covariate shift increases.
  • Figure C.1: Given data of the form $[\mathbf{x}_{\text{inv}}, \mathbf{x}_{\text{spu}}]$ following our regression setting, we train a model with a two-parameter feature layer, where one weight scales the invariant features and one weight scales the spurious features. For both the WRI and ERM methods, we use two initializations: equal weights on both features, and weight on the spurious features only. We find that optimizing WRI always leads to the spurious weights going to zero, whereas optimizing ERM converges to nonzero spurious weights.
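The invariance pattern in Figures 1 and 2 (a label whose dependence on $x_{inv}$ is stable while its dependence on $x_{spu}$ varies) can be checked on simulated data. The generator below is a hypothetical linear-Gaussian instance of one causal graph consistent with that pattern, $x_{inv} \to y \to x_{spu}$. The environment shifts the mean of $x_{inv}$ and the noise on $x_{spu}$. Its parameters are illustrative and are not the ones used in the paper's appendix. Least-squares slopes serve as a simple proxy for the conditionals $p(y|x_{inv})$ and $p(y|x_{spu})$.

```python
import numpy as np


def sample_env(rng, n, mu_inv, spu_noise):
    """Hypothetical environment: x_inv -> y -> x_spu, both arrows linear-Gaussian.

    mu_inv shifts the marginal of x_inv (invariant covariate shift);
    spu_noise controls how closely x_spu tracks y in this environment.
    """
    x_inv = rng.normal(mu_inv, 1.0, n)
    y = 2.0 * x_inv + rng.normal(0.0, 0.5, n)
    x_spu = y + rng.normal(0.0, spu_noise, n)
    return x_inv, x_spu, y


def ols_slope(x, y):
    """Least-squares slope of y regressed on x, with an intercept."""
    return float(np.polyfit(x, y, 1)[0])


rng = np.random.default_rng(0)
slopes = []
for mu_inv, spu_noise in [(0.0, 0.1), (1.5, 2.0)]:
    x_inv, x_spu, y = sample_env(rng, 100_000, mu_inv, spu_noise)
    slopes.append((ols_slope(x_inv, y), ols_slope(x_spu, y)))

for env, (s_inv, s_spu) in enumerate(slopes):
    print(f"env {env}: slope on x_inv = {s_inv:.3f}, slope on x_spu = {s_spu:.3f}")
```

With these settings the population slope on $x_{inv}$ is 2 in both environments, despite the shift in its mean. The slope on $x_{spu}$ is $4.25/(4.25+\sigma^2)$, about 0.998 and 0.515 for the two noise levels, so a predictor that trusts $x_{spu}$ in the first environment is miscalibrated in the second.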

Theorems & Definitions (9)

  • Definition 1
  • Proposition 0
  • Definition 2
  • Theorem 0
  • Proposition 0
  • Proof
  • Definition 3
  • Theorem 0
  • Proof