On the Out-of-Distribution Generalization of Self-Supervised Learning
Wenwen Qiang, Jingyao Wang, Zeen Song, Jiangmeng Li, Changwen Zheng
TL;DR
This paper addresses the challenge of out-of-distribution generalization in self-supervised learning by reframing SSL training as multi-task learning over a task distribution and identifying spurious correlations arising from unobserved factors. It introduces a causal framework with a post-intervention distribution (PID) that enforces mutual independence between a spurious variable and the anchor, yielding minimax optimal worst-case OOD risk. To realize PID in practice, the authors propose a two-stage method: first, learning a Regularized Latent Variable Model (RLVM) to capture task-conditioned distributions; second, a mini-batch sampling algorithm, based on a propensity-balancing score, that constructs batches approximating PID. Theoretical identifiability results and empirical validation across unsupervised, semi-supervised, transfer, and few-shot tasks indicate that PID-based batching reduces spurious correlations and improves OOD generalization. Overall, the work provides a principled, causality-grounded approach to robust SSL that turns batch construction into a tool for controlling distributional biases under distribution shift.
Abstract
In this paper, we focus on the out-of-distribution (OOD) generalization of self-supervised learning (SSL). By analyzing the mini-batch construction during the SSL training phase, we first offer one plausible explanation for why SSL exhibits OOD generalization. Then, from the perspective of data generation and causal inference, we show that SSL learns spurious correlations during training, which reduces its OOD generalization. To address this issue, we propose a post-intervention distribution (PID) grounded in the Structural Causal Model. PID describes a scenario in which the spurious variable and the label variable are mutually independent. Moreover, we demonstrate that if each mini-batch during SSL training satisfies PID, the resulting SSL model can achieve optimal worst-case OOD performance. This motivates us to develop a batch sampling strategy that enforces PID constraints through the learning of a latent variable model. Through theoretical analysis, we establish the identifiability of the latent variable model and validate the proposed sampling strategy. Experiments on various downstream OOD tasks further demonstrate its effectiveness.
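To make the idea of a mini-batch in which the spurious variable and the label are independent more concrete, the following is a minimal toy sketch. It is not the paper's algorithm: it assumes the spurious variable is observed and discrete (the paper instead infers it with a latent variable model), and it swaps in plain inverse-propensity weighting for the paper's sampling procedure. The function name `balanced_batch_indices` and its parameters are hypothetical.

```python
import numpy as np


def balanced_batch_indices(labels, spurious, batch_size, rng):
    """Sample indices so that `spurious` is roughly independent of `labels`.

    Toy assumption: both arrays hold observed, discrete values.
    """
    labels = np.asarray(labels)
    spurious = np.asarray(spurious)
    weights = np.empty(len(labels), dtype=float)
    for y in np.unique(labels):
        in_class = labels == y
        for s in np.unique(spurious[in_class]):
            cell = in_class & (spurious == s)
            # Empirical propensity P(s | y); weight each example by its inverse,
            # so every (y, s) cell carries the same total mass within class y.
            propensity = cell.sum() / in_class.sum()
            weights[cell] = 1.0 / propensity
    probs = weights / weights.sum()
    # Sample with replacement according to the balanced probabilities.
    return rng.choice(len(labels), size=batch_size, replace=True, p=probs)


# Usage: a dataset where the spurious attribute strongly tracks the label.
rng = np.random.default_rng(0)
labels = np.array([0] * 100 + [1] * 100)
spurious = np.array([0] * 90 + [1] * 10 + [0] * 10 + [1] * 90)
batch = balanced_batch_indices(labels, spurious, batch_size=256, rng=rng)
```

Under this weighting, each spurious value that occurs within a label class is drawn about equally often, so the sampled batch carries little of the original label-spurious correlation. Independence is only reachable when every spurious value appears in every class; cells that are empty in the data remain empty in the batch.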
