Invariant Risk Minimization

Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, David Lopez-Paz

TL;DR

IRM addresses distribution shift by requiring the learned data representation to support a single top classifier that is optimal across all training environments, which promotes out-of-distribution (OOD) generalization. The paper formalizes this requirement as a penalized objective in which invariance is enforced by a gradient-based penalty. This gives a concrete, implementable training criterion that extends to general losses and multivariate outputs. The paper also connects invariance to causality: given sufficiently diverse training environments, the learned invariances correspond to predicting from the direct causal parents of the target, which explains the improved OOD performance. Empirically, IRM outperforms ERM and prior methods on synthetic data and Colored MNIST, generalizing better under distribution shift and yielding more robust predictors that are more faithful to the underlying causal structure.
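The gradient-based penalty mentioned above can be sketched in a few lines. This is a minimal illustration under our own assumptions, not the authors' implementation: we use squared loss, fix a scalar "dummy" classifier w = 1.0 on top of the representation output Phi(x), and pass each environment as precomputed (phis, ys) lists. The function names env_risk, grad_wrt_w and irm_objective are hypothetical. For each environment the sketch adds the squared derivative of that environment's risk with respect to w; it is zero exactly when w = 1.0 is a stationary point, which for this convex loss means w = 1.0 is already optimal in that environment.

```python
# Minimal sketch of a gradient-based invariance penalty (IRMv1-style).
# Assumptions (ours, for illustration): squared loss, a scalar "dummy"
# classifier w fixed at 1.0 on top of the representation outputs, and each
# environment given as (phis, ys) with phis[i] = Phi(x_i).


def env_risk(phis, ys, w=1.0):
    """Mean squared error of the predictor w * Phi(x) in one environment."""
    return sum((w * p - y) ** 2 for p, y in zip(phis, ys)) / len(ys)


def grad_wrt_w(phis, ys, w=1.0):
    """Derivative of env_risk with respect to w: mean(2 * (w*p - y) * p)."""
    return sum(2.0 * (w * p - y) * p for p, y in zip(phis, ys)) / len(ys)


def irm_objective(envs, lam):
    """Sum over environments of risk + lam * (d risk / dw at w = 1.0) ** 2."""
    total = 0.0
    for phis, ys in envs:
        total += env_risk(phis, ys) + lam * grad_wrt_w(phis, ys) ** 2
    return total


if __name__ == "__main__":
    # Environment 1: Phi(x) equals y, so w = 1.0 is optimal and the penalty is 0.
    # Environment 2: Phi(x) = 2 * y, so w = 1.0 is not optimal there.
    envs = [([1.0, 2.0], [1.0, 2.0]), ([2.0, 4.0], [1.0, 2.0])]
    print(irm_objective(envs, lam=0.0))  # 2.5 (risk only)
    print(irm_objective(envs, lam=1.0))  # 102.5 (risk + penalty)
```

Setting lam = 0 recovers plain ERM summed over environments; larger values trade training risk for invariance. In practice Phi would be a trainable model and the objective would be minimized by gradient descent, which this sketch omits.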

Abstract

We introduce Invariant Risk Minimization (IRM), a learning paradigm to estimate invariant correlations across multiple training distributions. To achieve this goal, IRM learns a data representation such that the optimal classifier, on top of that data representation, matches for all training distributions. Through theory and experiments, we show how the invariances learned by IRM relate to the causal structures governing the data and enable out-of-distribution generalization.
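To make "the optimal classifier matches for all training distributions" concrete, the sketch below computes population least-squares coefficients in two environments of a toy linear-Gaussian model that we assume purely for illustration: X1 causes Y, Y causes X2, and the environment only changes the noise scale σ². The helper names population_moments and optimal_coefficients are hypothetical. Exact fractions are used so the comparison across environments is not blurred by rounding.

```python
from fractions import Fraction

# Toy linear-Gaussian structural model, assumed here for illustration:
#   X1 ~ N(0, s),  Y = X1 + N(0, s),  X2 = Y + N(0, 1),
# where the environment only changes the noise scale s = sigma^2.


def population_moments(sigma2):
    """Return Var(X1), Cov(X1, X2), Var(X2), Cov(X1, Y), Cov(X2, Y) exactly."""
    s = Fraction(sigma2)
    return s, s, 2 * s + 1, s, 2 * s


def optimal_coefficients(sigma2, features):
    """Population least-squares coefficients of Y on the chosen features."""
    v11, v12, v22, c1, c2 = population_moments(sigma2)
    if features == ("x1",):
        return (c1 / v11,)
    if features == ("x2",):
        return (c2 / v22,)
    if features == ("x1", "x2"):
        # Solve the 2x2 normal equations by Cramer's rule.
        det = v11 * v22 - v12 * v12
        return ((c1 * v22 - v12 * c2) / det, (v11 * c2 - v12 * c1) / det)
    raise ValueError(f"unsupported feature set: {features}")


if __name__ == "__main__":
    for sigma2 in (Fraction(1, 2), Fraction(2)):
        for feats in (("x1",), ("x2",), ("x1", "x2")):
            print(sigma2, feats, optimal_coefficients(sigma2, feats))
```

Among these three feature choices, only X1 alone yields the same optimal coefficient (1) in both environments. Using X2, alone or together with X1, gives coefficients that change with σ², for example 1/2 versus 4/5 for X2 alone.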

Paper Structure

This paper contains 28 sections, 4 theorems, 1 equation, 6 figures, 1 table.

Figures (6)

  • Figure 1: Different measures of invariance lead to different optimization landscapes in our Example 1. The naïve approach of measuring the distance between optimal classifiers $\mathbb{D}_\mathrm{dist}$ leads to a discontinuous penalty (solid blue unregularized, dashed orange regularized). In contrast, the penalty $\mathbb{D}_{\mathrm{lin}}$ does not exhibit these problems.
  • Figure 2: The solutions of the invariant linear predictors $v = \Phi^\top w$ coincide with the intersection of the ellipsoids representing the orthogonality condition $v^\top\nabla{R^e(v)}=0$.
  • Figure 3: In our synthetic experiments, the task is to predict $Y^e$ from $X^e = S(Z^e_1, Z^e_2)$.
  • Figure 4: Average errors on causal (plain bars) and non-causal (striped bars) weights for our synthetic experiments. The $y$-axes are in log-scale. See main text for details.
  • Figure 5: $P(y=1|h)$ as a function of $h$ for different models trained on Colored MNIST: (left) an ERM-trained model, (center) an IRM-trained model, and (right) an ERM-trained model which only sees grayscale images and therefore is perfectly invariant by construction. IRM learns approximate invariance from data alone and generalizes well to the test environment.
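The orthogonality condition in the Figure 2 caption can be checked directly for linear predictors under squared loss, where $R^e(v) = \mathbb{E}[(v^\top X^e - Y^e)^2]$ and therefore $\nabla R^e(v) = 2(\mathbb{E}[X^e X^{e\top}] v - \mathbb{E}[X^e Y^e])$. The sketch below uses a toy two-environment linear-Gaussian model that we assume for illustration (X1 causes Y, Y causes X2, environments differ only in the noise scale σ²); it is not the paper's experiment, and the helper names moments, orthogonality_residual and is_invariant are hypothetical. A predictor counts as invariant here when $v^\top \nabla R^e(v) = 0$ in every environment.

```python
from fractions import Fraction

# Toy zero-mean linear-Gaussian model assumed for illustration (not the paper's experiment):
#   X1 ~ N(0, s),  Y = X1 + N(0, s),  X2 = Y + N(0, 1),  with s = sigma^2 per environment.


def moments(sigma2):
    """Return E[X X^T] for X = (X1, X2) and E[X Y], as exact fractions."""
    s = Fraction(sigma2)
    second_moment = ((s, s), (s, 2 * s + 1))
    cross_moment = (s, 2 * s)
    return second_moment, cross_moment


def orthogonality_residual(v, sigma2):
    """v^T grad R^e(v) for squared loss, where grad R^e(v) = 2 (E[XX^T] v - E[XY])."""
    second, cross = moments(sigma2)
    grad = [2 * (second[i][0] * v[0] + second[i][1] * v[1] - cross[i]) for i in range(2)]
    return v[0] * grad[0] + v[1] * grad[1]


def is_invariant(v, sigma2_values):
    """True when the orthogonality condition holds in every listed environment."""
    return all(orthogonality_residual(v, s) == 0 for s in sigma2_values)


if __name__ == "__main__":
    envs = [Fraction(1, 2), Fraction(2)]
    print(is_invariant((Fraction(1), Fraction(0)), envs))     # True: uses X1 only
    print(is_invariant((Fraction(0), Fraction(4, 5)), envs))  # False: ERM-optimal on X2 for s = 2 only
```

The first predictor satisfies the condition in both environments; the second satisfies it only in the environment it was fitted to (s = 2), so it is not invariant. The trivial predictor v = 0 also satisfies the condition everywhere, so the condition by itself does not exclude degenerate solutions.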

Theorems & Definitions (9)

  • Example 1
  • Proposition 2
  • Definition 3
  • Theorem 4
  • Definition 5
  • Definition 6
  • Definition 7
  • Theorem 9
  • Theorem 10