Understanding Self-Supervised Learning via Latent Distribution Matching

Fabian A Mikulasch, Friedemann Zenke

Abstract

Self-supervised learning (SSL) excels at finding general-purpose latent representations from complex data, yet lacks a unifying theoretical framework that explains the diverse existing methods and guides the design of new ones. We cast SSL as latent distribution matching (LDM): learning representations that maximize their log-probability under an assumed latent model (alignment), while maximizing latent entropy to prevent collapse (uniformity). This view unifies independent component analysis with contrastive, non-contrastive, and predictive SSL methods, including stop gradient approaches. Leveraging LDM, we derive a nonlinear, sampling-free Bayesian filtering model with a Kalman-based predictor for high-dimensional timeseries. We further prove that predictive LDM yields identifiable latent representations under mild assumptions, even with nonlinear predictors. Overall, LDM clarifies the assumptions behind established SSL methods and provides principled guidance for developing new approaches.
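
In symbols, and following the notation of Figure 1, the LDM objective described above can be sketched as maximizing the sum of an alignment term and a uniformity term (a schematic paraphrase of the abstract; the exact weighting and conditioning used in the paper may differ): $\max_{f,\theta} \; \mathbb{E}_{R(z,z')}\left[\log P_\theta(z,z')\right] + H_R(z)$, where $R$ is the distribution of representations $z=f(x)$ induced by the data and the encoder $f$, $P_\theta$ is the assumed latent model, the expectation is the alignment term, and the entropy $H_R$ is the uniformity term that prevents collapse.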

Paper Structure

This paper contains 53 sections, 6 theorems, 76 equations, 14 figures, and 3 tables.

Key Result

Theorem 1 (Identifiability of predictive distribution matching under a Gaussian predictor)

Under the theorem's assumptions (stated in full in the paper), at the optimum of the DM objective, the learned representation recovers the true latent variables up to an affine transformation.
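
Spelled out (a hedged paraphrase; the formal statement and its assumptions are given in the paper): if $s$ denotes the true latent variables and $z = f(x)$ the learned representation, recovery up to an affine transformation means that at the optimum $z = A s + b$ for some invertible matrix $A$ and offset vector $b$.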

Figures (14)

  • Figure 1: We formulate SSL as a distribution matching problem in which the transformed data distribution $R(z,z')$ is matched to the latent model $P_\theta(z,z')$. The transformation is deterministic, $R(z|x)=\delta(z-f(x))$, where $f(x)$ is a deep network. The model likelihood $\log P_\theta$ and latent entropy $H_R$ correspond to the alignment and uniformity terms in the loss function [wang2020understanding].
  • Figure 2: Source recovery with DM in linear ICA (see the code sketch after this figure list). A Linear ICA assumes that the data distribution has independent factors that can be recovered by aligning them with the correct underlying independent distribution [cardoso2002infomax]. B Distributions of pixel intensities in natural images are non-Gaussian [hyvarinen1999independent]. In contrast, mixed images are closer to Gaussian, as expected from the central limit theorem. Disentanglement proceeds by learning $W$ to recover an assumed short-tailed distribution (red dashed line). C Gaussian sources can also be disentangled, which, however, requires additional assumptions on the data-generating process. Here we recover the outputs of two Ornstein-Uhlenbeck processes assuming known variances and a volume-preserving $W$ ($|\det W| = 1$).
  • Figure 3: Comparison of learned image representations on CIFAR-10. A The eigenspectrum of the learned representations generally decays more slowly for parametric entropy estimators, both on the plane (solid) and the sphere (dashed). Whether or not MI was maximized (+ MI) had little impact on the spectrum. The observed cutoff in the low double digits is consistent with previous estimates of the intrinsic dimensionality of CIFAR-10 [pope2021intrinsic]. B t-SNE embeddings [maaten2008visualizing] of the representations paint a similar picture, in which MI maximization has little impact. Color denotes label.
  • Figure 4: Predictive distribution matching in latent space using a nonlinear Bayesian filtering model with a Kalman-based predictor. A Example frames of a synthetic dataset: a high-dimensional noisy observable with linear latent dynamics. The red line denotes the ground-truth position. See Appendix, Fig. \ref{fig:A4a} for a more nonlinear task. B We use a Kalman filter backbone for the predictor $P_\theta(z_t|z_{:t})$ with hidden states $h_t$ and latent "observations" $z_t$. C Estimated position and ground truth over time. After learning, the true position and uncertainty can be linearly decoded from the Kalman hidden states $h_t$. Training with MI maximization (kNN entropy estimator) or without (stopgrad) results in approximately equal performance. D The cosine similarity between the network-weight gradients of the two entropy estimators also shows that both approaches lead to similar optimization. Similarity increases over training as the predictor becomes more accurate and, with it, the stopgrad entropy estimator. E Experimental setup in [grosmark2016diversity]. F The learned latent states $h_t$ are arranged in a circle according to the position and direction of the rat. One example trajectory over time is displayed in gray. G Example neural spiking timeseries and the position estimated by the model. Shaded areas denote 95% confidence intervals based on the model covariances $\Sigma_{h_t}$ (see Appendix \ref{app:simulation}).
  • Figure 5: System identification through predictive DM. A Forcing prediction errors into a Gaussian form leads to local linearization of the relation between true and recovered latent variables. B Schematic of nonlinear prediction task. Trajectory noise in the true latent space is Gaussian to enable identification. C Visualizations of the actual (left) and recovered latent space before (middle) and after (right) learning. Predictive DM recovers the true space up to affine transformations.
  • ...and 9 more figures
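
To make the distribution-matching picture of Figure 2 concrete, below is a minimal, hypothetical NumPy sketch of linear ICA as latent distribution matching: an unmixing matrix $W$ is learned so that $z = Wx$ is aligned with an assumed factorized source distribution, while the log-determinant term plays the role of the entropy/uniformity term. The log-cosh (super-Gaussian) prior and the natural-gradient update are illustrative choices, not necessarily those used in the paper (Figure 2, for instance, assumes a short-tailed source distribution, which would call for a different score function).

import numpy as np

def ica_distribution_matching(X, n_steps=2000, lr=0.01, seed=0):
    """Sketch: linear ICA as latent distribution matching (infomax / maximum likelihood).

    Objective: L(W) = E[log p_s(W x)] + log|det W|
    (alignment with the assumed source model plus a volume/entropy term).
    Here p_s is a factorized log-cosh (super-Gaussian) prior, so the score
    function is phi(z) = tanh(z); optimization uses Amari's natural gradient
    dW = lr * (I - E[phi(z) z^T]) W.  Illustrative assumptions throughout.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = np.eye(d) + 0.1 * rng.standard_normal((d, d))
    for _ in range(n_steps):
        Z = X @ W.T                                # candidate sources z = W x, shape (n, d)
        phi = np.tanh(Z)                           # score of the assumed log-cosh prior
        grad = (np.eye(d) - (phi.T @ Z) / n) @ W   # natural gradient of L(W)
        W += lr * grad
    return W

# Toy usage: mix two independent Laplace (super-Gaussian) sources and unmix them.
rng = np.random.default_rng(1)
S = rng.laplace(size=(5000, 2))
A = np.array([[1.0, 0.6], [0.4, 1.0]])   # unknown mixing matrix
X = S @ A.T
W = ica_distribution_matching(X)
print(W @ A)  # should be close to a scaled permutation matrix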

Theorems & Definitions (12)

  • Theorem 1: Identifiability of predictive distribution matching under a Gaussian predictor
  • Theorem 2
  • Definition 1: Data generation process
  • Definition 2
  • Theorem \ref{th:ident} (restated)
  • Corollary 1
  • Remark 1
  • Theorem \ref{th:identCPC} (restated)
  • Definition 3: Piecewise invertible
  • Definition 4: Folding entropy
  • ...and 2 more