Predicting What You Already Know Helps: Provable Self-Supervised Learning

Jason D. Lee, Qi Lei, Nikunj Saunshi, Jiacheng Zhuo

TL;DR

This work addresses why self-supervised learning (SSL) helps downstream tasks by formalizing approximate conditional independence (ACI) between the pretext target and the input, conditioned on the label $Y$. It shows that learning a representation $\psi(X_1)$ to predict the pretext target $X_2$ can implicitly encode $Y$, so that a linear predictor on top of $\psi$ enjoys favorable sample complexity bounds. The paper derives finite-sample excess-risk bounds under exact CI and extends them to ACI with latent variables, including a topic-model example that requires only $\mathcal{O}(k)$ labeled samples, and draws connections to nonlinear CCA, SimSiam, and ACE with analogous guarantees. Experiments on simulations and real CV/NLP tasks support the theory, showing substantial improvements in downstream performance with SSL representations and a practical pathway to reducing labeled-data requirements in diverse domains.
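As a toy illustration of the claim that a representation trained on the pretext task can linearly encode the label, the following population-level sketch checks the exact-CI case numerically. Everything in it is an assumption made for illustration, not the paper's experimental setup: $X_1$ takes finitely many values, $Y$ is one-hot over $k$ classes, the class means of $X_2$ form a random full-rank matrix `A`, and the names `post`, `psi`, and `W` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
k, m, d2 = 3, 6, 5  # number of classes, number of distinct X1 values, dim of X2

# post[i, y] = P(Y = y | X1 = i); each row is a distribution over the k labels.
post = rng.dirichlet(np.ones(k), size=m)   # shape (m, k)

# A[y] = E[X2 | Y = y]; a random Gaussian matrix has rank k almost surely.
A = rng.normal(size=(k, d2))               # shape (k, d2)

# Pretext stage (population version): under X1 independent of X2 given Y,
# psi*(x1) = E[X2 | X1 = x1] = sum_y P(Y = y | x1) * A[y].
psi = post @ A                             # shape (m, d2)

# Downstream stage: a single linear map W = (A A^T)^{-1} A applied to psi*.
W = np.linalg.solve(A @ A.T, A)            # shape (k, d2)
recovered = psi @ W.T                      # shape (m, k)

# The linear readout should reproduce P(Y | X1) up to round-off.
max_err = np.abs(recovered - post).max()
print(max_err)
```

The check works with exact conditional expectations, so it says nothing about finite-sample behavior; it only shows that, in this idealized setting, the label posterior is a linear function of the pretext predictor.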

Abstract

Self-supervised representation learning solves auxiliary prediction tasks (known as pretext tasks) without requiring labeled data to learn useful semantic representations. These pretext tasks are created solely using the input features, such as predicting a missing image patch, recovering the color channels of an image from context, or predicting missing words in text; yet predicting this known information helps in learning representations effective for downstream prediction tasks. We posit a mechanism exploiting the statistical connections between certain reconstruction-based pretext tasks that guarantees learning a good representation. Formally, we quantify how the approximate independence between the components of the pretext task (conditional on the label and latent variables) allows us to learn representations that can solve the downstream task by just training a linear layer on top of the learned representation. We prove that the linear layer yields small approximation error even for a complex ground-truth function class and drastically reduces labeled sample complexity. Next, we show that a simple modification of our method leads to nonlinear CCA, analogous to the popular SimSiam algorithm, and show similar guarantees for nonlinear CCA.

Paper Structure

This paper contains 58 sections, 33 theorems, 142 equations, 5 figures.

Key Result

Lemma 3.1

If random variables $X_1, X_2, Y$ satisfy the conditional independence assumption (assump:independence), and the matrix $\bm{A} \in \mathbb{R}^{\mathcal{Y}\times d_2}$ with rows $\bm{A}_{y,:} := \mathbb{E}[X_2 \mid Y=y]$ has rank $k=|\mathcal{Y}|$, then $f^*\equiv {\bm{W}}^* \psi^*$, i.e., $e_\text{apx}(\psi^*) = 0$.
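
The identity behind this lemma can be sketched in a few lines. The sketch assumes conventions that the statement suggests but does not spell out here: $Y$ is one-hot encoded in $\mathbb{R}^k$, $f^*(x_1) = \mathbb{E}[Y \mid X_1 = x_1]$, $\psi^*(x_1) = \mathbb{E}[X_2 \mid X_1 = x_1]$, and $X_1$ is independent of $X_2$ given $Y$.

```latex
\begin{align*}
\psi^*(x_1)
  &= \mathbb{E}[X_2 \mid X_1 = x_1] \\
  &= \sum_{y \in \mathcal{Y}} P(Y = y \mid X_1 = x_1)\,
     \mathbb{E}[X_2 \mid Y = y, X_1 = x_1] \\
  &= \sum_{y \in \mathcal{Y}} P(Y = y \mid X_1 = x_1)\,
     \mathbb{E}[X_2 \mid Y = y]
     && \text{(conditional independence)} \\
  &= \bm{A}^\top f^*(x_1).
\end{align*}
% A has full row rank k, so A A^T is invertible and A^T has a left inverse:
\begin{align*}
f^*(x_1) = (\bm{A}\bm{A}^\top)^{-1}\bm{A}\,\psi^*(x_1) =: \bm{W}^* \psi^*(x_1).
\end{align*}
```

Under these assumptions one valid choice is $\bm{W}^* = (\bm{A}\bm{A}^\top)^{-1}\bm{A}$, which gives $e_\text{apx}(\psi^*) = 0$.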

Figures (5)

  • Figure 1: Left two: how MSE scales with $k$ (the dimension of $Y$) and $\epsilon_{CI}$ (Definition 4.1) with the linear function class. Right two: how MSE scales with $k$ and $\epsilon$ with $\psi^*$ and the non-linear function class. The mean of $30$ trials is shown as a solid line and one standard error as a shaded band.
  • Figure 2: Left: Example of $X_2$ (inside the red box in the first row), $X_1$ (outside the red box in the first row), the input to the inpainting task (second row), and $\psi(X_1)$ (inside the red box in the third row); in this example $Y=1967$. Middle: Mean Squared Error comparison for yearbook regression predicting dates. Right: Mean Absolute Error comparison for yearbook regression predicting dates. Experiments are repeated 10 times, with the mean shown as a solid line and one standard deviation as a shaded band.
  • Figure 3: Left: MSE of using $\psi$ to predict $Y$ versus using $X_1$ directly to predict $Y$. Using $\psi$ consistently outperforms using $X_1$. Right: MSE of $\psi$ learned with different $n_1$. The MSE scales with $1 / n_2$, as indicated by our analysis. Simulations are repeated 100 times, with the mean shown as a solid line and one standard error as a shaded band.
  • Figure 4: Left: Mean Squared Error comparison of predicting gender and predicting date. Right: comparison of the spectrum of the covariance conditioned on gender and conditioned on date.
  • Figure 5: Performance on SST of baseline $\phi_1({\bm{x}}_1)$, i.e. bag-of-words, and learned $\psi({\bm{x}}_1)$ for the two settings. Left: Classification accuracy, Right: Regression MSE.

Theorems & Definitions (93)

  • Lemma 3.1: Approximation error
  • Proof: Proof sketch of Lemma 3.1
  • Example 3.1
  • Theorem 3.2: General conditional independence
  • Remark 3.1
  • Claim 3.3: Closed form solution
  • Lemma 3.4: Approximation error
  • Theorem 3.5
  • Definition 4.1: Approximate conditional independence with function space $\mathcal{H}$
  • Theorem 4.2
  • ...and 83 more