On the Generalization and Causal Explanation in Self-Supervised Learning

Wenwen Qiang, Zeen Song, Ziyin Gu, Jiangmeng Li, Changwen Zheng, Fuchun Sun, Hui Xiong

TL;DR

This work on generalization and causal explanation in SSL finds that early-layer features consistently learn generalizable information, while last-layer features tend to memorize the training data. It introduces coding rate reduction as a proxy for the degree of overfitting and presents the Undoing Memorization Mechanism (UMM), a bi-level optimization that aligns last-layer feature distributions with early-layer priors to maximize the coding rate reduction $\Delta R$ of the last-layer output. Through a causal-model lens, the authors explain how UMM helps recover generalization by focusing on task-relevant information, and they support this with extensive experiments across unsupervised, semi-supervised, transfer, and out-of-distribution settings. The results show that UMM yields robust improvements across multiple SSL methods and datasets, highlighting its practical potential for improving SSL generalization in real-world tasks. Overall, the paper offers a principled information-theoretic and causal framework for diagnosing and mitigating overfitting in SSL.

Abstract

Self-supervised learning (SSL) methods learn from unlabeled data and achieve high generalization performance on downstream tasks. However, they may also overfit their training data and lose the ability to adapt to new tasks. To investigate this phenomenon, we conduct experiments on various SSL methods and datasets and make two observations: (1) overfitting occurs abruptly in later layers and epochs, while generalizing features are learned in early layers throughout training; (2) coding rate reduction can be used as an indicator of the degree of overfitting in SSL models. Based on these observations, we propose the Undoing Memorization Mechanism (UMM), a plug-and-play method that mitigates overfitting of the pre-trained feature extractor by aligning the feature distributions of the early and last layers to maximize the coding rate reduction of the last-layer output. UMM is learned by solving a bi-level optimization problem. We provide a causal analysis of UMM to explain how it helps the pre-trained feature extractor overcome overfitting and recover generalization. We also demonstrate that UMM significantly improves the generalization performance of SSL methods on various downstream tasks.
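Coding rate reduction $\Delta R$ is the quantity both the TL;DR and the abstract use as an overfitting indicator. The NumPy sketch below implements the standard maximal coding rate reduction (MCR²) form, $\Delta R = R(Z) - \sum_j \frac{m_j}{m} R(Z_j)$ with $R(Z) = \frac{1}{2}\log\det\big(I + \frac{d}{m\epsilon^2} Z Z^\top\big)$. It assumes a hard partition of the samples given by `labels` and a distortion `eps = 0.5`. Both are illustrative assumptions: this listing does not say how groups are formed in the unsupervised setting, and the paper's exact variant may differ. The function names are hypothetical.

```python
import numpy as np

def coding_rate(Z: np.ndarray, eps: float = 0.5) -> float:
    """R(Z, eps) = 1/2 * logdet(I + d / (m * eps^2) * Z Z^T) for Z of shape (d, m)."""
    d, m = Z.shape
    gram = np.eye(d) + (d / (m * eps**2)) * (Z @ Z.T)
    _, logdet = np.linalg.slogdet(gram)  # gram is positive definite, so the sign is +1
    return 0.5 * logdet

def coding_rate_reduction(Z: np.ndarray, labels: np.ndarray, eps: float = 0.5) -> float:
    """Delta R = R(Z) - sum_j (m_j / m) * R(Z_j), where Z_j holds the columns of group j.

    Columns of Z are usually L2-normalized before calling this (MCR2 convention).
    """
    m = Z.shape[1]
    compressed = 0.0
    for j in np.unique(labels):
        Zj = Z[:, labels == j]
        compressed += (Zj.shape[1] / m) * coding_rate(Zj, eps)
    return coding_rate(Z, eps) - compressed

# Two orthogonal, equally sized groups of unit vectors in R^2.
Z = np.array([[1.0] * 5 + [0.0] * 5,
              [0.0] * 5 + [1.0] * 5])
labels = np.array([0] * 5 + [1] * 5)
print(coding_rate_reduction(Z, labels))  # ln(5/3) ≈ 0.511
```

If all samples are put in a single group, the compressed term equals $R(Z)$ and $\Delta R = 0$; features that spread into distinct, mutually orthogonal groups give a larger $\Delta R$.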

Paper Structure (27 sections, 6 theorems, 25 equations, 8 figures, 13 tables, 1 algorithm)

Key Result

Lemma

Denoting the volume of a region $w \in \Omega$ by ${\rm{vol}}(w)$ and the volume of the affinely transformed region $S(w)$ by ${\rm{vol}}(S(w))$, we have:
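The lemma's conclusion is cut off in this listing. For orientation only, the standard change-of-volume fact for an affine map $S(w) = Aw + b$ is ${\rm{vol}}(S(w)) = |\det A|\,{\rm{vol}}(w)$; this is not confirmed to be the lemma's exact statement. The sketch below checks the fact in 2-D by mapping the unit square with an arbitrary, hypothetical $A$ and $b$ and measuring areas with the shoelace formula.

```python
import numpy as np

def shoelace_area(vertices: np.ndarray) -> float:
    """Area of a simple polygon whose vertices (shape (k, 2)) are listed in order."""
    x, y = vertices[:, 0], vertices[:, 1]
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))

def affine(points: np.ndarray, A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Apply S(w) = A w + b to each row of points."""
    return points @ A.T + b

square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
A = np.array([[2.0, 1.0], [0.5, 3.0]])  # det A = 2*3 - 1*0.5 = 5.5
b = np.array([4.0, -1.0])               # translation does not change volume
image = affine(square, A, b)

print(shoelace_area(square))  # 1.0
print(shoelace_area(image))   # 5.5, i.e. |det A| * 1.0
```

Affine maps send a polygon to a polygon with the same vertex order (reversed orientation if $\det A < 0$), which is why taking the absolute value in the shoelace formula suffices.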

Figures (8)

  • Figure 1: Test accuracy versus training epoch for different SSL methods and datasets. Panels (a)-(e) use early-layer outputs, and panels (f)-(j) use last-layer outputs. Each curve is the average of 5 independent runs.
  • Figure 2: Coding rate reduction versus training epoch for different SSL methods and datasets. Panels (a)-(e) use early-layer outputs, and panels (f)-(j) use last-layer outputs. Each curve is the average of 5 independent runs.
  • Figure 3: The learning framework of the proposed UMM. Given a pre-trained SSL model, UMM aims to fine-tune module $f_{e-l}$ while keeping module $f_{e}$ frozen. The learning objective of UMM is a quadratic optimization problem.
  • Figure 4: SCM for the data generation process. $X$ and $X^{aug}$ represent the original and augmented datasets, respectively. $S_{r}$ denotes the task-relevant information, and $S_{ur}$ denotes the task-irrelevant information. $S_{ur}^{*}$ denotes the perturbed $S_{ur}$, in which only part of the task-irrelevant information in $S_{ur}$ has been changed.
  • Figure 5: (a)-(c) show test accuracy versus training epoch, and (d)-(f) show coding rate reduction versus training epoch. All results are based on the last-layer output.
  • ...and 3 more figures
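To make the SCM of Figure 4 concrete, here is a toy linear instance, assumed purely for illustration: $X = A S_r + B S_{ur}$ and $X^{aug} = A S_r + B S_{ur}^{*}$, so augmentation only moves the input along directions spanned by $B$. Any linear encoder $W$ that maps $X$ and $X^{aug}$ to the same output must annihilate those directions and therefore keeps only $S_r$. This mirrors the intuition that focusing on augmentation-invariant content favors task-relevant information; it is not the paper's model or proof, and $A$, $B$, the dimensions, and the projection-based encoder are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_r, d_ur, n = 10, 3, 4, 200

# Hypothetical linear generator: X = A S_r + B S_ur.
A = rng.normal(size=(d_x, d_r))   # mixing of task-relevant factors S_r
B = rng.normal(size=(d_x, d_ur))  # mixing of task-irrelevant factors S_ur

S_r = rng.normal(size=(d_r, n))
S_ur = rng.normal(size=(d_ur, n))
S_ur_star = S_ur + rng.normal(size=(d_ur, n))  # perturb only task-irrelevant factors

X = A @ S_r + B @ S_ur
X_aug = A @ S_r + B @ S_ur_star

# Directions along which augmentation changes the input (equal to range(B) here).
D = X - X_aug
U, s, _ = np.linalg.svd(D, full_matrices=False)
r = int(np.sum(s > 1e-8 * s[0]))  # numerical rank of D
Q = U[:, :r]                      # orthonormal basis of those directions

# Augmentation-invariant linear encoder: project out the perturbed directions.
W = np.eye(d_x) - Q @ Q.T

print(r)                                     # number of perturbed directions
print(np.abs(W @ B).max())                   # near zero: S_ur is removed
print(np.linalg.matrix_rank(W @ A))          # S_r survives the projection
```

Because $S_{ur}$ is perturbed in every direction across samples, invariance forces $WB \approx 0$, while generic $A$ and $B$ share no directions, so $WA$ keeps full rank.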

Theorems & Definitions (9)

  • Lemma
  • Theorem 1
  • Theorem 2
  • Lemma 1
  • Proof
  • Theorem 1
  • Proof
  • Theorem 2
  • Proof