Out-of-Distribution Generalization via Risk Extrapolation (REx)

David Krueger, Ethan Caballero, Joern-Henrik Jacobsen, Amy Zhang, Jonathan Binas, Dinghuai Zhang, Remi Le Priol, Aaron Courville

TL;DR

Out-of-distribution generalization is addressed by Risk Extrapolation (REx), a robust optimization approach that aims to flatten the risk landscape across training domains by matching training risks and extrapolating to unseen distributions. The authors introduce MM-REx and a simpler V-REx variant, connecting risk-based extrapolation to invariant prediction and potential causal discovery. They provide theoretical results linking equalized risks to learning the target's causal mechanism under common SCM assumptions, and demonstrate empirical gains over IRM on tasks involving covariate shift and interventional shifts, including colored MNIST, SEMs, DomainBed, and RL with partial observability. The work highlights trade-offs between robustness to causally induced shifts and covariate shift and positions REx as a practical alternative to invariant risk minimization in complex shift settings.

Abstract

Distributional shift is one of the major obstacles when transferring machine learning prediction systems from the lab to the real world. To tackle this problem, we assume that variation across training domains is representative of the variation we might encounter at test time, but also that shifts at test time may be more extreme in magnitude. In particular, we show that reducing differences in risk across training domains can reduce a model's sensitivity to a wide range of extreme distributional shifts, including the challenging setting where the input contains both causal and anti-causal elements. We motivate this approach, Risk Extrapolation (REx), as a form of robust optimization over a perturbation set of extrapolated domains (MM-REx), and propose a penalty on the variance of training risks (V-REx) as a simpler variant. We prove that variants of REx can recover the causal mechanisms of the targets, while also providing some robustness to changes in the input distribution ("covariate shift"). By appropriately trading off robustness to causally induced distributional shifts and covariate shift, REx is able to outperform alternative methods such as Invariant Risk Minimization in situations where these types of shift co-occur.
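
The two objectives named in the abstract can be sketched numerically. The block below is a minimal illustration, not the authors' implementation: it assumes V-REx adds `beta` times the (population) variance of the per-domain risks to their sum, and that MM-REx maximizes an affine combination of risks whose weights sum to one and are each bounded below by `lam_min`, which admits the closed form used here. The function names, `beta`, `lam_min`, and the example risk values are all hypothetical.

```python
import numpy as np

def vrex_objective(risks, beta):
    """Sum of per-domain risks plus beta times their variance (assumed V-REx form)."""
    risks = np.asarray(risks, dtype=float)
    return risks.sum() + beta * risks.var()

def mmrex_objective(risks, lam_min):
    """Worst case of sum_e lam_e * R_e subject to sum_e lam_e = 1, lam_e >= lam_min.

    The maximum puts lam_min on every domain and the remaining mass
    (1 - m * lam_min) on the domain with the highest risk.
    Assumes lam_min <= 1/m so that the constraint set is non-empty.
    """
    risks = np.asarray(risks, dtype=float)
    m = risks.size
    return (1.0 - m * lam_min) * risks.max() + lam_min * risks.sum()

# Hypothetical per-domain training risks for two domains.
unequal = [0.2, 0.6]
equal = [0.4, 0.4]

vrex_unequal = vrex_objective(unequal, beta=10.0)
vrex_equal = vrex_objective(equal, beta=10.0)
mmrex_interp = mmrex_objective(unequal, lam_min=0.0)    # convex hull only
mmrex_extrap = mmrex_objective(unequal, lam_min=-1.0)   # allows extrapolation
mmrex_equal = mmrex_objective(equal, lam_min=-1.0)
```

With `lam_min = 0` the MM-REx sketch reduces to the worst training risk, i.e. ordinary robust optimization over the convex hull; a negative `lam_min` enlarges the perturbation set and penalizes unequal risks more heavily. When all risks are equal, neither objective adds anything beyond the risks themselves, which is the sense in which both variants push toward a flat risk landscape.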

Paper Structure

This paper contains 47 sections, 6 theorems, 20 equations, 16 figures, 9 tables.

Key Result

Theorem 1

Given a Linear SEM, $X_i \leftarrow \sum_{j \neq i} \beta_{(i,j)} X_j + \varepsilon_i$, with $Y \doteq X_0$, and a predictor $f_\beta(X) \doteq \sum_{j: j > 0} \beta_j X_j + \varepsilon_j$ that satisfies REx (with mean-squared error) over a perturbation set of domains that contains 3 distinct $do()$

Figures (16)

  • Figure 1: Left: Robust optimization optimizes worst-case performance over the convex hull of training distributions. Right: By extrapolating risks, REx encourages robustness to larger shifts. Here $e_1, e_2,$ and $e_3$ represent training distributions, and $\vv{P^1(X,Y)}$, $\vv{P^2(X,Y)}$ represent some particular directions of variation in the affine space of quasiprobability distributions over $(X,Y)$.
  • Figure 2: Training accuracies (left) and risks (right) on colored MNIST domains with varying $P(Y=0 | \mathrm{color=red})$ after 500 epochs. Dots represent training risks, lines represent test risks on different domains. Increasing the V-REx penalty ($\beta$) leads to a flatter "risk plane" and more consistent performance across domains, as the model learns to ignore color in favor of shape-based invariant prediction. Note that $\beta=100$ gives the best worst-case risk across the 2 training domains, and so would be the solution preferred by DRO (Sagawa et al., 2019). This demonstrates that REx's counter-intuitive propensity to increase training risks can be necessary for good OOD performance.
  • Figure 3: Extrapolation can yield a distribution with negative $P(x)$ for some $x$. Left: $P(x)$ for domains $e_1$ and $e_2$. Right: Point-wise interpolation/extrapolation of $P^{e_1}(x)$ and $P^{e_2}(x)$. Since MM-REx targets worst-case robustness across extrapolated domains, it can provide robustness to such shifts in $P(X)$ (covariate shift).
  • Figure 4: REx outperforms IRM on Colored MNIST variants that include covariate shift. The x-axis indexes increasing amount of shift between training distributions, with $p=0$ corresponding to disjoint supports. Left: class imbalance, Center: shape imbalance, Right: color imbalance.
  • Figure 5: Performance and standard error on walker_walk (top), finger_spin (bottom).
  • ...and 11 more figures
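
The point-wise extrapolation described in the Figure 3 caption can be illustrated with a small numerical sketch. The two discrete distributions and the mixing weights below are made up for illustration and are not taken from the paper; `affine_mix` is a hypothetical helper name.

```python
import numpy as np

def affine_mix(p1, p2, lam):
    """Point-wise affine combination lam * p1 + (1 - lam) * p2."""
    p1 = np.asarray(p1, dtype=float)
    p2 = np.asarray(p2, dtype=float)
    return lam * p1 + (1.0 - lam) * p2

# Hypothetical P(x) over three values of x in two training domains.
p_e1 = [0.7, 0.2, 0.1]
p_e2 = [0.1, 0.2, 0.7]

# lam in [0, 1] interpolates and stays a valid probability distribution.
interp = affine_mix(p_e1, p_e2, lam=0.5)

# lam outside [0, 1] extrapolates: entries still sum to 1,
# but some can become negative (a quasiprobability distribution).
extrap = affine_mix(p_e1, p_e2, lam=1.5)
```

Here the extrapolated vector still sums to one but assigns negative mass to the last value of $x$, which is the situation the caption refers to.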

Theorems & Definitions (10)

  • Theorem 1
  • Theorem 2
  • Theorem 1
  • proof
  • Theorem 2
  • proof
  • Proposition 1
  • proof
  • Proposition 2
  • proof