Understanding self-supervised Learning Dynamics without Contrastive Pairs

Yuandong Tian, Xinlei Chen, Surya Ganguli

TL;DR

The paper tackles why non-contrastive self-supervised learning methods avoid representational collapse without negative samples. It develops a minimal two-layer linear BYOL/SimSiam-like model and derives its nonlinear learning dynamics, clarifying the roles of the predictor, stop-gradient, EMA, and weight decay. Under simplifying isotropic assumptions, it shows eigenspace alignment between the predictor and the input correlation matrix, invariant-parabola dynamics, and an EMA-driven curriculum that together explain a wide range of ablations and settings. The work then introduces DirectPred, an optimization-free predictor that sets its weights directly from input statistics and matches or surpasses gradient-trained predictors on STL-10, CIFAR-10, and ImageNet, showing that the analysis has direct practical use.

Abstract

While contrastive approaches to self-supervised learning (SSL) learn representations by minimizing the distance between two augmented views of the same data point (positive pairs) and maximizing the distance between views of different data points (negative pairs), recent non-contrastive SSL methods (e.g., BYOL and SimSiam) show remarkable performance without negative pairs, using an extra learnable predictor and a stop-gradient operation. A fundamental question arises: why do these methods not collapse into trivial representations? We answer this question via a simple theoretical study and propose a novel approach, DirectPred, that directly sets the linear predictor based on the statistics of its inputs, without gradient training. On ImageNet, it performs comparably with more complex two-layer non-linear predictors that employ BatchNorm and outperforms a linear predictor by 2.5% in 300-epoch training (and by 5% in 60-epoch training). DirectPred is motivated by our theoretical study of the nonlinear learning dynamics of non-contrastive SSL in simple linear networks. Our study yields conceptual insights into how non-contrastive SSL methods learn, how they avoid representational collapse, and how multiple factors, such as predictor networks, stop-gradients, exponential moving averages, and weight decay, all come into play. Our simple theory recapitulates the results of real-world ablation studies on both STL-10 and ImageNet. Code is released at https://github.com/facebookresearch/luckmatters/tree/master/ssl.
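The abstract describes DirectPred only at a high level, so the following is a minimal sketch of one plausible reading, not the released implementation. It assumes that (i) the correlation matrix $F = \mathbb{E}[f f^\top]$ of the predictor's inputs is tracked with a moving average whose weight is a hypothetical `rho`, (ii) the predictor reuses the eigenvectors of $F$ (the eigenspace alignment predicted by the theory), and (iii) its eigenvalues are set to $p_j \propto \sqrt{s_j}$, normalized by the largest one and floored by a hypothetical `eps`, echoing the parabola $s_j = p_j^2/\alpha_p$. The function name and default values are illustrative assumptions.

```python
import numpy as np


def direct_pred(features, F_prev=None, rho=0.3, eps=0.1):
    """Hypothetical DirectPred-style update (sketch, not the released code).

    features: (batch, d) array of predictor inputs f (online-network outputs).
    F_prev:   previous running estimate of F = E[f f^T], or None on the first call.
    rho:      assumed weight kept on the previous estimate.
    eps:      assumed floor added to every predictor eigenvalue.
    Returns (W_p, F): the symmetric predictor and the updated estimate of F.
    """
    # Batch estimate of the correlation matrix E[f f^T].
    F_batch = features.T @ features / features.shape[0]
    F = F_batch if F_prev is None else rho * F_prev + (1.0 - rho) * F_batch
    # F is symmetric PSD: eigh returns real eigenvalues s_j and orthonormal eigenvectors.
    s, U = np.linalg.eigh(F)
    s = np.clip(s, 0.0, None)  # drop tiny negative values caused by round-off
    # p_j proportional to sqrt(s_j), scaled so the largest equals 1, then floored by eps.
    root = np.sqrt(s)
    p = root / max(root.max(), 1e-12) + eps
    # Same eigenvectors as F, so W_p is symmetric PSD and aligned with F.
    W_p = U @ np.diag(p) @ U.T
    return W_p, F


rng = np.random.default_rng(0)
f = rng.normal(size=(256, 8))        # stand-in for a batch of online features
W_p, F = direct_pred(f)              # first call: F is the batch estimate
W_p, F = direct_pred(f, F_prev=F)    # later calls blend in the running estimate
```

Because $W_p$ is built from the eigenvectors of $F$, every eigenvector $\mathbf{u}_j$ of $F$ satisfies $W_p \mathbf{u}_j = p_j \mathbf{u}_j$, so the alignment cosine used in Figure 2 equals 1 by construction; no gradient step on $W_p$ is involved.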

Paper Structure

This paper contains 22 sections, 9 theorems, 81 equations, 10 figures, 9 tables.

Key Result

Lemma 1

BYOL learning dynamics under the training objective:
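A sketch of what such dynamics look like, assuming the two-layer linear setting of Figure 1: online weights $W$, linear predictor $W_p$, EMA target weights $W_a$, two independently augmented views $x_1, x_2$ of a data point $x$, and the loss $J = \frac{1}{2}\mathbb{E}\left\|W_p W x_1 - \mathrm{StopGrad}(W_a x_2)\right\|^2$. Gradient flow with relative predictor learning rate $\alpha_p$, weight decay $\eta$, and EMA rate $\beta$ then gives the system below, where $X = \mathbb{E}[\bar{x}\bar{x}^\top]$ with $\bar{x}$ the mean augmented view of $x$, and $X'$ is the covariance of the augmented views given $x$, averaged over the data. This is a derivation under those assumptions rather than a quotation of the lemma.

```latex
\begin{align*}
\dot W_p &= \alpha_p \left(-W_p W (X + X') + W_a X\right) W^\top - \eta W_p, \\
\dot W   &= W_p^\top \left(-W_p W (X + X') + W_a X\right) - \eta W, \\
\dot W_a &= \beta \left(-W_a + W\right).
\end{align*}
```

The $X + X'$ factor comes from $\mathbb{E}[x_1 x_1^\top]$, while $\mathbb{E}[x_2 x_1^\top] = X$ because two independently augmented views share only their mean $\bar{x}$. Because of the stop-gradient, $W_a$ receives no gradient and moves only through the EMA; discretizing the last line with step $\alpha$ gives $W_a \leftarrow (1 - \alpha\beta) W_a + \alpha\beta W$, i.e. $\gamma_\mathrm{a} = 1 - \alpha\beta$.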

Figures (10)

  • Figure 1: Two-layer setting with a linear, bias-free predictor.
  • Figure 2: Training BYOL on STL-10 for 100 epochs with EMA. Top row: no symmetric regularization imposed on $W_p$. Bottom row: symmetric regularization on $W_p$. From left to right: (1) Evolution of the eigenvalues of $F$. Since $F$ is PSD and its eigenvalues $s_j$ vary across scales, we plot $\log(s_j)$. Some eigenvalues grow while others shrink to zero over training. (2) Similar "step-function" behavior for the predictor $W_p$: its negative eigenvalues shrink towards zero and its leading eigenvalues become larger. (3) The eigenspaces of $F$ and $W_p$ gradually align with each other (Theorem 3). For each eigenvector $\mathbf{u}_j$ of $F$, we compute the cosine of the angle (normalized correlation) between $\mathbf{u}_j$ and $W_p \mathbf{u}_j$ to measure alignment. (4) $W_p$ gradually becomes symmetric and PSD during training.
  • Figure 3: State-space dynamics of the decoupled equations for $p_j$ and $s_j$ with no ($\eta=0$), weak ($\eta=0.01$), and strong ($\eta=1$) weight decay, at fixed $\tau=1$ and $\alpha_p=1$. Red (green) points indicate stable (unstable) fixed points, blue curves indicate flow lines, and the dashed black curve indicates the parabola $s_j=p_j^2/\alpha_p$.
  • Figure 4: Fixed points of $\dot p_j = -p_j(p_j - p^*_{j-})(p_j - p^*_{j+})$. Stable fixed points are in red, unstable ones in green, and saddles in black. When the weight decay $\eta = 0$, the trivial solution $p_j=0$ is a saddle. When $\eta > 0$, the trivial solution becomes stable near the origin, and the initial $p_j$ must be large enough to converge to the stable non-collapsed solution $p^*_{j+}$.
  • Figure 5: The roles of weight decay $\eta$ and EMA $\beta$ when applying symmetric regularization on $W_p$, in synthetic experiments simulating the decoupled dynamics. The learning rate is $\alpha = 0.01$. Both terms push the eigenvalues of $K(t)$ above $0$ so that eigenspace alignment can happen (Theorem 3), but they come with different trade-offs. Here $\beta = 0.4$, so that $\alpha\beta = 0.004 = 1 - \gamma_\mathrm{a}$, where $\gamma_\mathrm{a} = 0.996$ as in BYOL. Top row (weight decay $\eta$): a large $\eta$ boosts the eigenvalues of $K(t)$, but substantially decreases the final converged eigenvalues $p_j$ and $s_j$ (i.e., the final features are not salient), or even drags them to zero (no training happens). Bottom row (EMA $\beta$): a small $\beta$ also boosts the eigenvalues of $K(t)$, but training converges much more slowly. Here $\eta = 0.04$, so that $\eta\alpha$ equals the weight decay ($\bar{\eta}=0.0004$) used in our STL-10 experiments.
  • ...and 5 more figures
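The one-dimensional dynamics in the Figure 4 caption can be integrated directly to see the collapse threshold. A minimal sketch, assuming forward-Euler steps and hypothetical roots $p^*_{j-} = 0.2$, $p^*_{j+} = 1$ for the weight-decay case; for $\eta = 0$ it assumes $p^*_{j-} = 0$, which makes $p_j = 0$ a double root and hence the saddle described in the caption. The true roots depend on $\eta$ and the other model parameters, so these values are illustrative only.

```python
def simulate_p(p0, p_minus, p_plus, dt=0.01, t_end=100.0):
    """Forward-Euler integration of dp/dt = -p (p - p_minus) (p - p_plus)."""
    p = p0
    for _ in range(int(round(t_end / dt))):
        p += dt * (-p * (p - p_minus) * (p - p_plus))
    return p


# Weight decay (eta > 0): hypothetical roots 0 < p_minus < p_plus.
for p0 in (0.1, 0.5, 1.5):
    print(f"eta > 0, p0 = {p0}: p(T) = {simulate_p(p0, 0.2, 1.0):.4f}")

# No weight decay (eta = 0): assume p_minus = 0, so p = 0 is a saddle.
print(f"eta = 0, p0 = 0.1: p(T) = {simulate_p(0.1, 0.0, 1.0):.4f}")
```

With $\eta > 0$, a start below $p^*_{j-}$ decays to the collapsed solution $0$, while starts above it settle at $p^*_{j+}$; with $\eta = 0$, even a small positive start slowly escapes the origin and reaches $p^*_{j+}$.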

Theorems & Definitions (14)

  • Lemma 1
  • Theorem 1: Weight decay promotes balancing of the predictor and online networks.
  • Theorem 2: The stop-gradient signal is essential for success.
  • Theorem 3: Eigenspace alignment
  • Lemma 1: Dynamics of BYOL/SimSiam
  • proof
  • Theorem 1: Invariance of the Gradient Update
  • proof
  • Lemma 2: Dynamics of a negative definite system
  • proof
  • ...and 4 more
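Theorem 1 (weight decay promotes balancing) can be made concrete with a short numerical check. A minimal sketch, assuming the gradient-flow form $\dot W_p = \alpha_p G W^\top - \eta W_p$ and $\dot W = W_p^\top G - \eta W$ with a shared error term $G = -W_p W (X + X') + W_a X$; this is an assumed form of the BYOL/SimSiam dynamics, and the helper names are hypothetical. By the product rule the cross terms cancel, leaving $\frac{d}{dt}\left(\alpha_p^{-1} W_p^\top W_p - W W^\top\right) = -2\eta\left(\alpha_p^{-1} W_p^\top W_p - W W^\top\right)$ for any $G$. The script checks this identity at a random point.

```python
import numpy as np


def lemma1_rates(W, W_p, W_a, X, Xp, alpha_p, eta):
    """Assumed gradient-flow rates (hypothetical helper) for the online
    weights W and the linear predictor W_p in the two-layer linear model."""
    G = -W_p @ W @ (X + Xp) + W_a @ X   # shared error term
    dW_p = alpha_p * G @ W.T - eta * W_p
    dW = W_p.T @ G - eta * W
    return dW, dW_p


def imbalance(W, W_p, alpha_p):
    """D = W_p^T W_p / alpha_p - W W^T."""
    return W_p.T @ W_p / alpha_p - W @ W.T


def imbalance_rate(W, W_p, dW, dW_p, alpha_p):
    """dD/dt computed with the product rule."""
    d_pp = dW_p.T @ W_p + W_p.T @ dW_p
    d_ww = dW @ W.T + W @ dW.T
    return d_pp / alpha_p - d_ww


rng = np.random.default_rng(0)
n1, n2 = 6, 4                        # input and feature dimensions (arbitrary)
W = rng.normal(size=(n2, n1))
W_p = rng.normal(size=(n2, n2))
W_a = rng.normal(size=(n2, n1))      # EMA target weights (held fixed here)
A = rng.normal(size=(n1, n1))
B = rng.normal(size=(n1, n1))
X, Xp = A @ A.T, 0.1 * B @ B.T       # PSD stand-ins for X and X'
alpha_p, eta = 2.0, 0.01

dW, dW_p = lemma1_rates(W, W_p, W_a, X, Xp, alpha_p, eta)
lhs = imbalance_rate(W, W_p, dW, dW_p, alpha_p)
print(np.allclose(lhs, -2 * eta * imbalance(W, W_p, alpha_p)))
```

Integrating gives $W_p^\top(t) W_p(t) - \alpha_p W(t) W^\top(t) = e^{-2\eta t}\left[W_p^\top(0) W_p(0) - \alpha_p W(0) W^\top(0)\right]$: weight decay shrinks the imbalance exponentially, and with $\eta = 0$ it is conserved. Assuming isotropic inputs (so $F = W W^\top$) and a symmetric $W_p$ aligned with $F$, this pulls the eigenvalues onto the parabola $s_j = p_j^2/\alpha_p$ shown in Figure 3.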