Sample Weight Averaging for Stable Prediction
Han Yu, Yue He, Renzhe Xu, Dongbai Li, Jiayin Zhang, Wenchao Zou, Peng Cui
TL;DR
This work addresses covariate shift in Out-of-Distribution (OOD) generalization by tackling the variance inflation suffered by independence-based sample reweighting methods. It introduces SAmple Weight Averaging (SAWA), which ensembles multiple weight learners with different random initializations and averages their outputs into a single weighting function $\bar{w}(\boldsymbol{X})$. The diversity among the learners reduces variance without requiring environment labels, and the learners can be trained in parallel. The authors theoretically justify the validity of the averaged weights and derive a bias-variance decomposition showing that lower variance translates into more accurate coefficient estimation and stable prediction under covariate shift. Empirical results on synthetic and real-world datasets show consistent improvements in covariate-shift generalization across multiple baseline reweighting methods and tasks, highlighting SAWA's generality and practical value for robust learning in non-IID settings.
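The averaging step can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: `learn_decorrelation_weights` is a hypothetical stand-in weight learner that minimizes the squared off-diagonal entries of the weighted covariance matrix by plain gradient descent, with weights parameterized as $n \cdot \mathrm{softmax}(\theta)$; the decorrelation losses and optimizers of the actual reweighting methods may differ. In this sketch the only source of diversity is the random initialization of $\theta$ (one seed per learner), the averaged weights play the role of $\bar{w}(\boldsymbol{X})$, and `sawa_weights` and `weighted_least_squares` are likewise illustrative names. The step size assumes standardized covariates.

```python
import numpy as np


def learn_decorrelation_weights(X, seed, n_iter=500, lr=0.05):
    """Hypothetical weight learner: gradient descent on the sum of squared
    off-diagonal entries of the weighted covariance matrix of X.

    Weights are n * softmax(theta), so they are positive and average to 1.
    `seed` sets the random initialization of theta, the only source of
    diversity between learners in this sketch. Assumes standardized X.
    """
    n = X.shape[0]
    rng = np.random.default_rng(seed)
    theta = rng.normal(scale=0.5, size=n)              # random initialization
    for _ in range(n_iter):
        p = np.exp(theta - theta.max())
        p /= p.sum()                                   # normalized weights
        mu = X.T @ p                                   # weighted mean
        C = (X * p[:, None]).T @ X - np.outer(mu, mu)  # weighted covariance
        G = 2.0 * (C - np.diag(np.diag(C)))            # d loss / d C, off-diagonal only
        # d loss / d p_i = x_i^T G x_i - 2 x_i^T G mu
        g = np.einsum("ij,jk,ik->i", X, G, X) - 2.0 * X @ (G @ mu)
        # chain rule through the softmax; the factor n keeps steps of order lr
        theta -= lr * n * p * (g - p @ g)
    p = np.exp(theta - theta.max())
    return n * p / p.sum()


def sawa_weights(X, n_learners=5, **learner_kwargs):
    """Average the weights of independently initialized learners (SAWA-style)."""
    W = np.stack([learn_decorrelation_weights(X, seed=k, **learner_kwargs)
                  for k in range(n_learners)])
    return W.mean(axis=0)


def weighted_least_squares(X, y, w):
    """Fit y ~ 1 + X with sample weights w; returns [intercept, coefs...]."""
    A = np.column_stack([np.ones(len(X)), X])
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef
```

For standardized covariates `X` and outcome `y`, `weighted_least_squares(X, y, sawa_weights(X))` returns the intercept followed by the reweighted coefficients. Because the learners never communicate, the list comprehension in `sawa_weights` could be dispatched to parallel workers without changing the result.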
Abstract
Out-of-Distribution (OOD) generalization is a fundamental concern when applying machine learning algorithms to risk-sensitive areas. Inspired by traditional importance weighting and propensity weighting methods, prior approaches employ an independence-based sample reweighting procedure. They aim to decorrelate covariates in order to counteract the bias introduced by spurious correlations between unstable variables and the outcome, thereby enhancing generalization and achieving stable prediction under covariate shift. Nonetheless, these methods are prone to variance inflation, primarily because reweighting reduces how efficiently the training samples are used. Existing remedies require either environment labels or substantially higher computational cost, along with additional assumptions and supervised information. To mitigate this issue, we propose SAmple Weight Averaging (SAWA), a simple yet effective strategy that can be integrated into a wide range of sample reweighting algorithms to decrease the variance and the coefficient estimation error, thereby improving covariate-shift generalization and achieving stable prediction across different environments. We prove its validity and benefits theoretically. Experiments on both synthetic and real-world datasets consistently demonstrate its advantage under covariate shift.
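One common way to quantify this loss of sample efficiency (not necessarily the measure used in the paper) is Kish's effective sample size, $n_{\text{eff}} = (\sum_i w_i)^2 / \sum_i w_i^2$, which equals $n$ for uniform weights and shrinks as the weights become more uneven. When $K$ weight vectors each sum to $n$, convexity of $x^2$ gives $\sum_i \bar{w}_i^2 \le \frac{1}{K}\sum_k \sum_i w_{k,i}^2$, so the averaged weights have an effective sample size at least as large as the harmonic mean of the individual ones. The sketch below checks this on randomly generated log-normal weights, which are an assumption standing in for the outputs of different weight learners; the helper names are illustrative.

```python
import random


def effective_sample_size(w):
    """Kish's effective sample size (sum w)^2 / sum w^2 (illustrative helper)."""
    return sum(w) ** 2 / sum(wi * wi for wi in w)


def average_weights(weight_sets):
    """Element-wise mean of K weight vectors of equal length (illustrative helper)."""
    k = len(weight_sets)
    return [sum(col) / k for col in zip(*weight_sets)]


def random_weights(n, rng):
    """Synthetic log-normal weights rescaled to sum to n (assumed stand-in for one learner)."""
    raw = [rng.lognormvariate(0.0, 1.0) for _ in range(n)]
    total = sum(raw)
    return [n * r / total for r in raw]


if __name__ == "__main__":
    rng = random.Random(0)
    n = 1000
    weight_sets = [random_weights(n, rng) for _ in range(5)]
    for w in weight_sets:
        print(f"single learner: n_eff = {effective_sample_size(w):.0f}")
    w_bar = average_weights(weight_sets)
    print(f"averaged:       n_eff = {effective_sample_size(w_bar):.0f}")
```

How much the effective sample size grows in practice depends on how much the learners disagree; identical weight vectors gain nothing from averaging.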
