On the Generalization and Causal Explanation in Self-Supervised Learning
Wenwen Qiang, Zeen Song, Ziyin Gu, Jiangmeng Li, Changwen Zheng, Fuchun Sun, Hui Xiong
TL;DR
This work on generalization and causal explanation in self-supervised learning (SSL) finds that early-layer features consistently learn generalizable information, whereas last-layer features tend to memorize the training data. It introduces coding rate reduction as an indicator of overfitting and presents the Undoing Memorization Mechanism (UMM), a bi-level optimization that aligns last-layer feature distributions with early-layer priors to maximize the coding rate reduction $\Delta R$ of the last layer. Through a causal-model lens, the authors explain how UMM helps recover generalization by focusing on task-relevant information. This explanation is supported by extensive experiments in unsupervised, semi-supervised, transfer, and out-of-distribution settings. The results show that UMM yields consistent improvements across multiple SSL methods and datasets, which points to its practical value for improving SSL generalization in real-world tasks. Overall, the paper offers a principled framework, grounded in information theory and causality, for diagnosing and mitigating overfitting in SSL.
Abstract
Self-supervised learning (SSL) methods learn from unlabeled data and achieve high generalization performance on downstream tasks. However, they may also overfit their training data and lose the ability to adapt to new tasks. To investigate this phenomenon, we conduct experiments on various SSL methods and datasets and make two observations: (1) overfitting occurs abruptly in later layers and later epochs, while generalizing features are learned in early layers throughout training; (2) coding rate reduction can be used as an indicator of the degree of overfitting in SSL models. Based on these observations, we propose the Undoing Memorization Mechanism (UMM), a plug-and-play method that mitigates overfitting of the pre-trained feature extractor by aligning the feature distributions of the early and last layers so as to maximize the coding rate reduction of the last-layer output. UMM is learned through bi-level optimization. We provide a causal analysis of UMM to explain how it helps the pre-trained feature extractor overcome overfitting and recover generalization. We also demonstrate that UMM significantly improves the generalization performance of SSL methods on various downstream tasks.
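As a rough illustration of the quantity behind observation (2), the sketch below computes the coding rate reduction using the standard definition from the maximal coding rate reduction (MCR²) principle: for features $Z \in \mathbb{R}^{d \times n}$ split into groups $Z_j$ of size $n_j$, $R(Z) = \tfrac{1}{2}\log\det\big(I + \tfrac{d}{n\epsilon^2} Z Z^\top\big)$, $R_c = \sum_j \tfrac{n_j}{2n}\log\det\big(I + \tfrac{d}{n_j\epsilon^2} Z_j Z_j^\top\big)$, and $\Delta R = R - R_c$. This is a minimal sketch under stated assumptions, not the authors' implementation: the helper names (`coding_rate`, `grouped_coding_rate`, `coding_rate_reduction`), the precision `eps = 0.5`, and the use of a given group assignment `labels` are all hypothetical, since the abstract does not say how UMM partitions features or chooses $\epsilon$.

```python
import numpy as np


def _normalize_columns(Z: np.ndarray) -> np.ndarray:
    # Project each feature vector (one column per sample) onto the unit sphere.
    return Z / np.linalg.norm(Z, axis=0, keepdims=True)


def coding_rate(Z: np.ndarray, eps: float = 0.5) -> float:
    # R(Z, eps) = 1/2 * logdet(I + d / (n * eps^2) * Z Z^T), Z has shape (d, n).
    d, n = Z.shape
    _, logdet = np.linalg.slogdet(np.eye(d) + (d / (n * eps**2)) * Z @ Z.T)
    return 0.5 * logdet


def grouped_coding_rate(Z: np.ndarray, labels: np.ndarray, eps: float = 0.5) -> float:
    # R_c = sum_j n_j / (2n) * logdet(I + d / (n_j * eps^2) * Z_j Z_j^T).
    d, n = Z.shape
    total = 0.0
    for g in np.unique(labels):
        Zg = Z[:, labels == g]
        ng = Zg.shape[1]
        _, logdet = np.linalg.slogdet(np.eye(d) + (d / (ng * eps**2)) * Zg @ Zg.T)
        total += ng / (2 * n) * logdet
    return total


def coding_rate_reduction(Z: np.ndarray, labels: np.ndarray, eps: float = 0.5) -> float:
    # Delta R = R(Z) - R_c(Z | labels), computed on unit-normalized features.
    Z = _normalize_columns(Z)
    return coding_rate(Z, eps) - grouped_coding_rate(Z, labels, eps)


# Toy check on synthetic features (hypothetical data, not from the paper).
rng = np.random.default_rng(0)
d, n_per = 8, 100
labels = np.repeat([0, 1], n_per)

# Two groups occupying orthogonal 4-dim subspaces: well-separated features.
orth = np.zeros((d, 2 * n_per))
orth[:4, :n_per] = rng.standard_normal((4, n_per))
orth[4:, n_per:] = rng.standard_normal((4, n_per))

# Both groups sharing one 4-dim subspace: collapsed, indistinguishable features.
same = np.zeros((d, 2 * n_per))
same[:4, :] = rng.standard_normal((4, 2 * n_per))

print("Delta R (orthogonal groups):", coding_rate_reduction(orth, labels))
print("Delta R (shared subspace):  ", coding_rate_reduction(same, labels))
```

The separated configuration yields a clearly larger $\Delta R$ than the collapsed one, and by concavity of $\log\det$ the value is never negative. With real models, one could pass early-layer and last-layer features of the same batch through such a function; under observation (2), a lower last-layer value would be read as a sign of overfitting, although how the paper normalizes or compares these values across layers of different width is not stated here.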
