Weighted Risk Invariance: Domain Generalization under Invariant Feature Shift
Gina Wong, Joshua Gleason, Rama Chellappa, Yoav Wald, Anqi Liu
TL;DR
The paper studies out-of-distribution (OOD) generalization under invariant covariate shift, where the conditional $p(Y \mid X_{\text{inv}})$ stays stable across environments but the marginal $p(X_{\text{inv}})$ shifts. It introduces Weighted Risk Invariance (WRI), which enforces invariance of reweighted losses across environments, with weights given by densities of the invariant features, and proves that WRI learns invariant models (i.e., discards spurious correlations) in a linear-Gaussian setting. An alternating minimization algorithm jointly learns the predictor and the invariant-feature densities, and WRI is shown empirically to outperform IRM and VREx under invariant covariate shift on the ColoredMNIST and DomainBed benchmarks. The learned densities also provide a useful OOD-detection signal, so WRI pairs a principled, density-aware approach to domain generalization with a practical benefit for trustworthiness.
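One way to write the reweighted-risk condition compactly is sketched here. The specific weighting is an assumption: this summary only says the weights are densities of the invariant features, so the cross-environment form (weight the loss in environment $e$ by the invariant-feature density of environment $e'$, and vice versa) is our reading rather than a quotation of the paper. Here $p^{e}(x)$ denotes the density of $X_{\text{inv}}$ in environment $e$, and $\bar{\ell}(x) = \mathbb{E}[\ell(f(x), Y) \mid X_{\text{inv}} = x]$. If $f$ depends on $X$ only through $X_{\text{inv}}$ and $p(Y \mid X_{\text{inv}})$ is shared across environments, $\bar{\ell}$ is the same function in every environment, which gives the third line:

```latex
\begin{align*}
R^{e}_{w}(f) &= \mathbb{E}_{p^{e}(X,Y)}\big[\, w(X_{\text{inv}})\, \ell(f(X), Y) \,\big] \\
\text{WRI:}\quad R^{e}_{p^{e'}}(f) &= R^{e'}_{p^{e}}(f) \qquad \forall\, e \neq e' \\
R^{e}_{p^{e'}}(f) &= \int p^{e}(x)\, p^{e'}(x)\, \bar{\ell}(x)\, dx = R^{e'}_{p^{e}}(f) \\
\text{penalty (one natural choice):}\quad &\lambda \sum_{e < e'} \big( R^{e}_{p^{e'}}(f) - R^{e'}_{p^{e}}(f) \big)^{2}
\end{align*}
```

By contrast, the unweighted risks $\int p^{e}(x)\,\bar{\ell}(x)\,dx$ generally differ across environments when $p^{e}(X_{\text{inv}})$ shifts, so requiring plain risk equality can penalize a predictor that is in fact invariant.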
Abstract
Learning models whose predictions are invariant under multiple environments is a promising approach for out-of-distribution generalization. Such models are trained to extract features $X_{\text{inv}}$ where the conditional distribution $Y \mid X_{\text{inv}}$ of the label given the extracted features does not change across environments. Invariant models are also supposed to generalize to shifts in the marginal distribution $p(X_{\text{inv}})$ of the extracted features $X_{\text{inv}}$, a type of shift we call an $\textit{invariant covariate shift}$. However, we show that proposed methods for learning invariant models underperform under invariant covariate shift, either failing to learn invariant models (even for data generated from simple and well-studied linear-Gaussian models) or having poor finite-sample performance. To alleviate these problems, we propose $\textit{weighted risk invariance}$ (WRI). Our framework is based on imposing invariance of the loss across environments subject to appropriate reweightings of the training examples. We show that WRI provably learns invariant models, i.e., discards spurious correlations, in linear-Gaussian settings. We propose a practical algorithm to implement WRI by learning the density $p(X_{\text{inv}})$ and the model parameters simultaneously, and we demonstrate empirically that WRI outperforms previous invariant learning methods under invariant covariate shift.
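To make "invariance of the loss across environments subject to appropriate reweightings" concrete, the following is a minimal NumPy sketch, not the authors' implementation. It rests on several assumptions: a hypothetical one-dimensional toy generator (`sample_env`) in which the noise scale depends only on $X_{\text{inv}}$; the cross-environment density weighting described earlier; the true Gaussian densities of $X_{\text{inv}}$ plugged in, whereas WRI learns them; and hypothetical helper names (`wri_penalty`, `unweighted_penalty`). It compares two fixed predictors, one using only $X_{\text{inv}}$ and one using a spurious feature.

```python
import numpy as np


def gaussian_pdf(x, mean, std):
    """Density of N(mean, std^2) evaluated elementwise at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def sample_env(rng, n, mu_inv, sigma_sp):
    """Hypothetical toy environment (an assumption, not from the paper).

    x_inv ~ N(mu_inv, 1); y = x_inv + noise whose scale 0.5*|x_inv| depends only
    on x_inv, so p(y | x_inv) is identical in every environment. The spurious
    feature x_sp = y + N(0, sigma_sp^2) has environment-dependent reliability.
    """
    x_inv = rng.normal(mu_inv, 1.0, n)
    y = x_inv + 0.5 * np.abs(x_inv) * rng.normal(size=n)
    x_sp = y + sigma_sp * rng.normal(size=n)
    return x_inv, x_sp, y


def wri_penalty(losses, x_inv, densities):
    """Sum over environment pairs (i, j) of (R^i_{p^j} - R^j_{p^i})^2.

    R^i_{p^j} is the mean loss in environment i, with each example weighted by
    the density of its invariant feature under environment j (assumed weighting).
    """
    penalty = 0.0
    for i in range(len(losses)):
        for j in range(i + 1, len(losses)):
            r_ij = np.mean(densities[j](x_inv[i]) * losses[i])
            r_ji = np.mean(densities[i](x_inv[j]) * losses[j])
            penalty += (r_ij - r_ji) ** 2
    return float(penalty)


def unweighted_penalty(losses):
    """Pairwise squared differences of plain per-environment risks (no weights)."""
    risks = [np.mean(loss) for loss in losses]
    return float(sum((risks[i] - risks[j]) ** 2
                     for i in range(len(risks)) for j in range(i + 1, len(risks))))


rng = np.random.default_rng(0)
env_params = [(0.0, 0.1), (2.0, 2.0)]  # (mean of x_inv, spurious noise std) per env
envs = [sample_env(rng, 50_000, mu, s) for mu, s in env_params]
# Assumption: the true densities of x_inv are known here; WRI itself learns them.
densities = [lambda x, mu=mu: gaussian_pdf(x, mu, 1.0) for mu, _ in env_params]
x_invs = [x_inv for x_inv, _, _ in envs]

# Squared-error losses of two fixed predictors in each environment.
inv_losses = [(x_inv - y) ** 2 for x_inv, _, y in envs]  # f(x) = x_inv
sp_losses = [(x_sp - y) ** 2 for _, x_sp, y in envs]     # f(x) = x_sp

wri_inv = wri_penalty(inv_losses, x_invs, densities)
wri_sp = wri_penalty(sp_losses, x_invs, densities)
plain_inv = unweighted_penalty(inv_losses)
print(f"WRI penalty, invariant predictor:        {wri_inv:.2e}")
print(f"WRI penalty, spurious predictor:         {wri_sp:.2e}")
print(f"Unweighted penalty, invariant predictor: {plain_inv:.2e}")
```

In this toy setup the invariant predictor should get a near-zero weighted penalty but a large unweighted (VREx-style risk-difference) penalty, because its loss grows with $|X_{\text{inv}}|$ and the second environment puts more mass at large $|X_{\text{inv}}|$. The spurious predictor, whose loss depends on environment-specific noise, should still be penalized after reweighting.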
