On the Out-of-Distribution Generalization of Self-Supervised Learning

Wenwen Qiang, Jingyao Wang, Zeen Song, Jiangmeng Li, Changwen Zheng

TL;DR

This paper addresses out-of-distribution (OOD) generalization in self-supervised learning (SSL) by reframing SSL training as multi-task learning over a task distribution and identifying spurious correlations that arise from unobserved factors. It introduces a causal framework with a post-intervention distribution (PID) that enforces mutual independence between a spurious variable and the anchor, yielding minimax-optimal worst-case OOD risk. To realize PID in practice, the authors propose a two-stage method: first, learning a Regularized Latent Variable Model (RLVM) to capture task-conditioned distributions; second, a mini-batch sampling algorithm, based on a propensity-balancing score, that constructs batches approximating PID. Theoretical identifiability results and empirical validation across unsupervised, semi-supervised, transfer, and few-shot tasks indicate that PID-based batching reduces spurious correlations and substantially improves OOD generalization. Overall, the work offers a principled, causality-grounded approach to robust SSL that turns batch construction into a tool for controlling distributional bias in shift-prone environments.
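
To make the idea of score-based batch construction concrete, here is a minimal sketch of generic nearest-neighbor matching on a balancing score. It is not the authors' algorithm: the function name `balanced_batch`, the `scores` array (assumed to come from some already-fitted latent variable model), and the one-sample-per-label batch layout are all assumptions made for illustration.

```python
import numpy as np


def balanced_batch(scores, labels, rng):
    """Return indices of one sample per label, matched on balancing score.

    scores: (n,) or (n, d) array of per-sample balancing scores (hypothetical
            input, assumed to be estimated elsewhere).
    labels: (n,) array of anchor/label identifiers.
    rng:    a numpy.random.Generator used to pick the starting sample.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float).reshape(len(labels), -1)

    # Pick a random starting sample; its score is the matching target.
    anchor = int(rng.integers(len(labels)))
    batch = [anchor]

    for lab in np.unique(labels):
        if lab == labels[anchor]:
            continue
        candidates = np.flatnonzero(labels == lab)
        # Distance from each candidate's score to the starting sample's score.
        dists = np.linalg.norm(scores[candidates] - scores[anchor], axis=1)
        batch.append(int(candidates[np.argmin(dists)]))

    return batch
```

Because every sample in the returned batch has a similar score, the label carries little information about the score within that batch, which is the intuition behind matching-based balancing. Calling `balanced_batch(scores, labels, np.random.default_rng(0))` repeatedly would produce successive batches.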

Abstract

In this paper, we focus on the out-of-distribution (OOD) generalization of self-supervised learning (SSL). By analyzing mini-batch construction during the SSL training phase, we first give one plausible explanation for why SSL exhibits OOD generalization. Then, from the perspective of data generation and causal inference, we show that SSL learns spurious correlations during training, which reduces its OOD generalization. To address this issue, we propose a post-intervention distribution (PID) grounded in the Structural Causal Model. PID describes a scenario in which the spurious variable and the label variable are mutually independent. Moreover, we demonstrate that if each mini-batch during SSL training satisfies PID, the resulting SSL model can achieve optimal worst-case OOD performance. This motivates us to develop a batch sampling strategy that enforces PID constraints through the learning of a latent variable model. Through theoretical analysis, we demonstrate the identifiability of the latent variable model and validate the effectiveness of the proposed sampling strategy. Experiments on various downstream OOD tasks confirm this effectiveness empirically.
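
Read literally, the independence property stated above corresponds to a factorization of the joint distribution. The line below is a sketch of that reading only: the symbols $x^{\rm label}$ (label variable) and $s$ (spurious variable) and the superscript ${\rm PI}$ are borrowed from the figure captions on this page, and the paper's formal definition may include further conditions.

$$
p^{\rm PI}(x^{\rm label}, s) = p^{\rm PI}(x^{\rm label})\, p^{\rm PI}(s)
$$

Under this reading, knowing $s$ tells a model nothing about $x^{\rm label}$ within a mini-batch drawn from PID, so $s$ cannot serve as a shortcut.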

Paper Structure

This paper contains 30 sections, 5 theorems, 37 equations, 6 figures, 13 tables, 1 algorithm.

Key Result

Proposition 3.1

Revisiting SSL from a pairwise perspective, and assuming that the two samples in each pair satisfy c3:eq1, the learned SSL model will use a non-causal factor, i.e., the unobserved latent variable $s$, to measure similarity or perform reconstruction within a pair.

Figures (6)

  • Figure 1: The SCM for c3:eq1.
  • Figure 2: The SCM for $p^{\rm PI}(x^{+}, x^{\rm label}, s)$.
  • Figure 3: Two specific instances illustrate the variability in the causal relationship between $x^{\rm label}$ and $s$ due to environmental changes. The black squares are variables and the arrows indicate causality.
  • Figure 4: Influence of the hyperparameter $a$.
  • Figure 5: Influence of the hyperparameter $\alpha$.
  • ...and 1 more figure

Theorems & Definitions (9)

  • Proposition 3.1
  • Definition 3.2
  • Theorem 3.4
  • Definition 4.2
  • Theorem 4.3
  • Definition 4.4
  • Definition 4.5
  • Corollary 4.6
  • Theorem 4.7