Out-of-distribution Generalization for Total Variation based Invariant Risk Minimization
Yuanchao Wang, Zhao-Rong Lai, Tianqi Zhong
TL;DR
This work tackles out-of-distribution generalization by recasting invariant risk minimization through a total-variation (TV) lens and introducing OOD-TV-IRM, a primal-dual framework in which the TV penalty strength is the Lagrangian multiplier $\lambda(\Psi,\Phi)$. The primal update minimizes the invariant risk to learn stable features, while the dual update strengthens the TV penalty to adversarially suppress spurious correlations; the target is a semi-Nash equilibrium between the two. A convergent primal-dual algorithm is proposed, with neural-network implementations of the feature extractor $\Phi$, the environment inference $\rho$, and the penalty scheduler $\lambda$ that make training practical. Across seven real and synthetic datasets, OOD-TV-IRM and OOD-TV-Minimax improve mean and worst-case performance over the IRM and IRM-TV baselines in most settings, indicating stronger robustness to distribution shift.
Abstract
Invariant risk minimization (IRM) is an important general machine learning framework that has recently been interpreted as a total variation model (IRM-TV). However, how to improve out-of-distribution (OOD) generalization in the IRM-TV setting remains an open problem. In this paper, we extend IRM-TV to a Lagrangian multiplier model named OOD-TV-IRM. We find that the autonomous TV penalty hyperparameter is exactly the Lagrangian multiplier. Thus, OOD-TV-IRM is essentially a primal-dual optimization model, where the primal optimization minimizes the entire invariant risk and the dual optimization strengthens the TV penalty. The objective is to reach a semi-Nash equilibrium that maintains the balance between the training loss and OOD generalization. We also develop a convergent primal-dual algorithm that facilitates an adversarial learning scheme. Experimental results show that OOD-TV-IRM outperforms IRM-TV in most situations.
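To make the primal-dual idea concrete, the following is a minimal toy sketch, not the algorithm of the paper. It assumes a single scalar feature weight `phi`, a dummy classifier fixed at $w = 1$, two hypothetical environments summarized by made-up statistics, and a squared-gradient penalty in the style of IRMv1 as a smooth stand-in for the TV term. The multiplier is parameterized here as `softplus(psi)` and depends only on `psi`, which is a simplification of $\lambda(\Psi,\Phi)$. The primal step descends on `phi`; the dual step ascends on `psi`.

```python
import math

# Hypothetical per-environment statistics (mean of x^2, mean of x*y) for a
# scalar feature x whose correlation with y changes across environments.
ENVS = [(1.0, 0.9), (1.0, 0.3)]


def risk_grad_phi(phi):
    """d/dphi of the mean squared-error risk, with the dummy classifier w = 1."""
    return sum(2 * a * phi - 2 * b for a, b in ENVS) / len(ENVS)


def penalty(phi):
    """Mean over environments of (dR_e/dw at w = 1)^2 (assumed penalty form)."""
    return sum((2 * a * phi**2 - 2 * b * phi) ** 2 for a, b in ENVS) / len(ENVS)


def penalty_grad_phi(phi):
    """Analytic d/dphi of penalty(phi)."""
    total = 0.0
    for a, b in ENVS:
        g = 2 * a * phi**2 - 2 * b * phi
        total += 2 * g * (4 * a * phi - 2 * b)
    return total / len(ENVS)


def softplus(psi):
    """Keeps the multiplier positive (an assumption of this sketch)."""
    return math.log1p(math.exp(psi))


def sigmoid(psi):
    """Derivative of softplus."""
    return 1.0 / (1.0 + math.exp(-psi))


def primal_dual(steps=500, lr_primal=0.01, lr_dual=0.1):
    phi = 0.6  # minimizer of the unpenalized mean risk for ENVS
    psi = 0.0  # dual parameter; multiplier is softplus(psi)
    for _ in range(steps):
        lam = softplus(psi)
        # Primal step: gradient descent on risk + lam * penalty.
        phi -= lr_primal * (risk_grad_phi(phi) + lam * penalty_grad_phi(phi))
        # Dual step: gradient ascent on the same objective w.r.t. psi.
        psi += lr_dual * sigmoid(psi) * penalty(phi)
    return phi, softplus(psi)


if __name__ == "__main__":
    phi, lam = primal_dual()
    print(f"phi = {phi:.4f}, lambda = {lam:.4f}")
```

In this toy setting the dual ascent keeps raising the multiplier while the penalty is nonzero, which pulls `phi` below the unpenalized solution and so reduces reliance on the environment-dependent feature. The step sizes and iteration count are arbitrary choices for illustration.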
