How JEPA Avoids Noisy Features: The Implicit Bias of Deep Linear Self Distillation Networks

Etai Littwin, Omid Saremi, Madhu Advani, Vimal Thilak, Preetum Nakkiran, Chen Huang, Joshua Susskind

TL;DR

The paper analyzes the implicit bias of JEPA versus MAE in deep linear networks to explain why JEPA often emphasizes semantic, low-variance, high-influence features. By solving the training dynamics analytically, it shows that JEPA induces a depth-dependent bias toward features with large regression coefficients $\rho$, while MAE emphasizes directions with large covariance $\lambda$. The key contributions are exact ODE characterizations of the JEPA and MAE dynamics, explicit critical-time formulas, and empirical validation showing that depth amplifies the JEPA bias. This work deepens the understanding of latent-space prediction in SSL and clarifies when JEPA-based methods may preferentially extract meaningful, robust representations for downstream tasks.
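As a rough illustration of the two quantities, the sketch below estimates $\lambda_i$ and $\rho_i$ from paired samples. It assumes zero-mean data and that $\Sigma^{xx}=\mathbb{E}[xx^\top]$ and $\Sigma^{yx}=\mathbb{E}[yx^\top]$ share an eigenbasis; $\lambda_i$ are the eigenvalues of $\Sigma^{xx}$, and $\rho_i$ is taken as the diagonal of $\Sigma^{yx}$ in that basis divided by $\lambda_i$, i.e. the regression coefficient of $y$ on $x$ along each direction. It also returns a mean squared off-diagonal error of the kind reported in Figure 3(f). The helper name `estimate_lambda_rho` and the synthetic data are hypothetical, not taken from the paper.

```python
import numpy as np

def estimate_lambda_rho(x, y):
    """Hypothetical helper: per-direction covariance and regression coefficient.

    x, y: arrays of shape (n_samples, d), assumed zero-mean.
    Returns (lam, rho, offdiag_mse) with directions sorted by increasing lam.
    """
    n = x.shape[0]
    sigma_xx = x.T @ x / n                  # empirical E[x x^T]
    sigma_yx = y.T @ x / n                  # empirical E[y x^T]
    lam, u = np.linalg.eigh(sigma_xx)       # eigenvalues ascending, eigenvectors in columns
    d_yx = u.T @ sigma_yx @ u               # cross-covariance in the eigenbasis of sigma_xx
    off = d_yx[~np.eye(len(lam), dtype=bool)]
    offdiag_mse = float(np.mean(off ** 2))  # zero iff that eigenbasis also diagonalizes sigma_yx
    rho = np.diag(d_yx) / lam               # regression coefficient along each eigendirection
    return lam, rho, offdiag_mse

# Synthetic check: independent features with variances lam_true and y = rho_true * x + noise.
rng = np.random.default_rng(0)
n = 200_000
lam_true = np.array([0.5, 2.0])
rho_true = np.array([0.9, 0.3])
x = rng.standard_normal((n, 2)) * np.sqrt(lam_true)
y = x * rho_true + 0.1 * rng.standard_normal((n, 2))
lam, rho, err = estimate_lambda_rho(x, y)
print(lam, rho, err)
```

In this example the high-variance direction has the smaller regression coefficient, which is the kind of split between $\lambda$ and $\rho$ on which the two objectives are said to disagree.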

Abstract

Two competing paradigms exist for self-supervised learning of data representations. Joint Embedding Predictive Architecture (JEPA) is a class of architectures in which semantically similar inputs are encoded into representations that are predictive of each other. A recent successful approach that falls under the JEPA framework is self-distillation, where an online encoder is trained to predict the output of the target encoder, sometimes using a lightweight predictor network. This is contrasted with the Masked AutoEncoder (MAE) paradigm, where an encoder and decoder are trained to reconstruct missing parts of the input in the data space rather than in its latent representation. A common motivation for using the JEPA approach over MAE is that the JEPA objective prioritizes abstract features over fine-grained pixel information (which can be unpredictable and uninformative). In this work, we seek to understand the mechanism behind this empirical observation by analyzing the training dynamics of deep linear models. We uncover a surprising mechanism: in a simplified linear setting where both approaches learn similar representations, JEPAs are biased to learn high-influence features, i.e., features characterized by having high regression coefficients. Our results point to a distinct implicit bias of predicting in latent space that may shed light on its success in practice.
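The difference between the two objectives is easiest to see for a single feature. The sketch below assumes (our simplification, not the paper's code) input variance $\mathbb{E}[x^2]=\lambda$, cross-moment $\mathbb{E}[xy]=\rho\lambda$, a scalar deep linear encoder $\bar w = w_1\cdots w_L$, and a scalar predictor or decoder $v$. The JEPA loss $\tfrac12\mathbb{E}[(v\bar w x - \mathrm{sg}(\bar w)\,y)^2]$ reuses the encoder as its own target, with $\mathrm{sg}$ a stop-gradient; the MAE loss $\tfrac12\mathbb{E}[(v\bar w x - y)^2]$ reconstructs $y$ in data space. The hand-derived gradients are checked against central finite differences, with the JEPA target frozen at its current value. All function names are illustrative.

```python
import numpy as np

def jepa_loss(ws, v, w_target, lam, rho, syy=1.0):
    # 0.5 * E[(v * wbar * x - w_target * y)^2]; w_target plays the role of sg(wbar).
    wbar = np.prod(ws)
    return 0.5 * (v**2 * wbar**2 * lam - 2 * v * wbar * w_target * rho * lam + w_target**2 * syy)

def mae_loss(ws, v, lam, rho, syy=1.0):
    # 0.5 * E[(v * wbar * x - y)^2]: reconstruct y itself.
    wbar = np.prod(ws)
    return 0.5 * (v**2 * wbar**2 * lam - 2 * v * wbar * rho * lam + syy)

def jepa_grads(ws, v, lam, rho):
    # Gradient only through the online branch; the target copy of wbar is frozen.
    wbar = np.prod(ws)
    g_wbar = v * wbar * lam * (v - rho)
    g_v = wbar**2 * lam * (v - rho)
    return g_wbar * wbar / ws, g_v          # d wbar / d w_a = wbar / w_a (all w_a nonzero)

def mae_grads(ws, v, lam, rho):
    wbar = np.prod(ws)
    resid = v * wbar - rho
    return v * lam * resid * wbar / ws, wbar * lam * resid

def fd_grads(f, ws, v, h=1e-6):
    # Central finite differences of f(ws, v) with respect to every w_a and v.
    g_ws = np.zeros_like(ws)
    for a in range(len(ws)):
        e = np.zeros_like(ws)
        e[a] = h
        g_ws[a] = (f(ws + e, v) - f(ws - e, v)) / (2 * h)
    return g_ws, (f(ws, v + h) - f(ws, v - h)) / (2 * h)

ws, v, lam, rho = np.array([0.7, 1.1, 0.9]), 0.8, 2.0, 0.5
w_target = np.prod(ws)                      # stop-gradient: freeze the target at the current encoder
fd_j = fd_grads(lambda w, vv: jepa_loss(w, vv, w_target, lam, rho), ws, v)
fd_m = fd_grads(lambda w, vv: mae_loss(w, vv, lam, rho), ws, v)
an_j, an_m = jepa_grads(ws, v, lam, rho), mae_grads(ws, v, lam, rho)
print(np.allclose(an_j[0], fd_j[0]), np.isclose(an_j[1], fd_j[1]))
print(np.allclose(an_m[0], fd_m[0]), np.isclose(an_m[1], fd_m[1]))
```

Freezing `w_target` inside the finite-difference closure is what the stop-gradient means here: without it, the derivative with respect to $\bar w$ would also pick up the target's terms.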

Paper Structure (26 sections, 22 theorems, 138 equations, 5 figures)

Key Result

Theorem 4.2

Suppose $\{W^a\}_{a=1,\dots,L}$ and $V$ are initialized according to ass:init. Let $\bar{w}_i=\|\bar{W}e_i\|$, where $\{e_i\}$ is the standard basis. Furthermore, assume the JEPA objective in eqn:lin is optimized using gradient flow according to eqn:gf. Then, we have: [equation missing]. Similarly, when the MAE objective in eqn:lin is optimized using gradient flow according to eqn:gf, we obtain: [equation missing]
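To see what an ODE-equivalence statement of this kind looks like in the smallest case, consider a simplification that we introduce here (it is not the paper's ass:init): one feature, scalar layers $w_1,\dots,w_L$ with $\bar w = w_1\cdots w_L$, a scalar predictor/decoder $v$, $\mathbb{E}[x^2]=\lambda$, $\mathbb{E}[xy]=\rho\lambda$, and a balanced initialization in which every $w_a$ and $v$ start at the same positive value. For both objectives, gradient flow gives $\tfrac{d}{dt}v^2 = \tfrac{d}{dt}w_a^2$, so $v = w_a = \bar w^{1/L}$ for all time, and substituting into the per-feature gradients yields closed ODEs (our derivation, which may differ from Theorem 4.2 in normalization; the factor $L$ only rescales time):

$$\dot{\bar w} = L\lambda\,\bar w^{3-1/L}\big(\rho - \bar w^{1/L}\big)\ \ \text{(JEPA)},\qquad \dot{\bar w} = L\lambda\big(\rho\,\bar w^{2-1/L} - \bar w^{3}\big)\ \ \text{(MAE)}.$$

The sketch runs small-step gradient descent on the full parameters next to Euler steps of the reduced ODE and checks that both reach the same fixed points, $\bar w = \rho^L$ for JEPA and $\bar w = \rho^{L/(L+1)}$ for MAE. Function names and hyperparameters are illustrative.

```python
import numpy as np

def grads(ws, v, lam, rho, objective):
    # Per-feature gradients for the scalar deep linear model (w_1..w_L, v).
    wbar = np.prod(ws)
    if objective == "jepa":                 # target sg(wbar) * y: no gradient through the target
        g_wbar = v * wbar * lam * (v - rho)
        g_v = wbar**2 * lam * (v - rho)
    else:                                   # "mae": reconstruct y in data space
        resid = v * wbar - rho
        g_wbar = v * lam * resid
        g_v = wbar * lam * resid
    return g_wbar * wbar / ws, g_v

def reduced_rhs(wbar, lam, rho, L, objective):
    # Reduced ODE for wbar under balanced init (our derivation for this scalar case).
    if objective == "jepa":
        return L * lam * wbar ** (3 - 1 / L) * (rho - wbar ** (1 / L))
    return L * lam * (rho * wbar ** (2 - 1 / L) - wbar**3)

def simulate(lam, rho, L, objective, init=0.7, dt=1e-3, steps=20_000):
    ws, v = np.full(L, init), init          # balanced init: all layers and v equal
    wbar_red = init**L                      # state of the reduced ODE
    for _ in range(steps):
        g_ws, g_v = grads(ws, v, lam, rho, objective)
        ws, v = ws - dt * g_ws, v - dt * g_v   # small-step gradient descent ~ gradient flow
        wbar_red += dt * reduced_rhs(wbar_red, lam, rho, L, objective)
    return float(np.prod(ws)), wbar_red

L, lam, rho = 3, 1.0, 1.2
full_j, red_j = simulate(lam, rho, L, "jepa")
full_m, red_m = simulate(lam, rho, L, "mae")
print(full_j, red_j, rho**L)                # JEPA fixed point: wbar = rho**L
print(full_m, red_m, rho ** (L / (L + 1)))  # MAE fixed point: wbar = rho**(L/(L+1))
```

In this scalar sketch the JEPA fixed point $\rho^L$ depends exponentially on depth, while the MAE fixed point $\rho^{L/(L+1)}$ stays between $\rho^{1/2}$ and $\rho$; this is one concrete way in which depth interacts with $\rho$ only for JEPA here.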

Figures (5)

  • Figure 1: Deep linear model trained using the MAE and JEPA objectives (eqn:lin). Feature indices ($x$-axis) are ordered such that the covariance $\lambda_i$ is monotonically increasing and $\rho_i$ is (a)-(c) monotonically increasing or (d)-(f) monotonically decreasing. (b), (c): Both objectives learn features in the same order on distribution 1. In (e), (f), on distribution 2, the MAE objective maintains the same learning order as in (b), whereas the JEPA objective reverses the learning order due to its sensitivity to $\rho_i$.
  • Figure 2: Simulations of the equivalent JEPA and MAE ODEs (eqn:dynj, eqn:dynm). Each curve is a numerical simulation of the corresponding ODE for different values of $\rho$ and $\lambda$. (a), (d): darker curves correspond to higher $\lambda$, with $\rho = 1$. (b), (e): darker curves correspond to higher $\rho$, with $\lambda = 1$. Both objectives exhibit greedy learning dynamics with respect to $\lambda$; however, only JEPA exhibits greedy dynamics with respect to $\rho$. (c), (f): darker curves correspond to higher $\lambda$ but lower $\rho$. In this case, the order of learning is inverted between the objectives because of the opposing trends in $\rho$ and $\lambda$.
  • Figure 3: Temporal model (eq:gen_temp): (a) $v_1, v_2$; (b) temporal dynamics of the flickering $u$, with autocorrelations $\gamma_1 = 0.99$ and $\gamma_2 = 0.95$; (c) empirical covariance $\hat{\Sigma}^{xx}$ from 500k samples; (d) empirical correlation $\hat{\Sigma}^{xy}$; (e) predicted versus simulated values of $\lambda$ and $\rho$ as $\log_{10}(T)$ varies: $\lambda$ decreases from feature 1 to feature 2, whereas $\rho$ increases because the standard deviation of the added noise decreases from 1 to 0.5; (f) simultaneous diagonalizability, measured by the mean squared error of the off-diagonal elements when $\hat{\Sigma}^{xy}$ is diagonalized using the eigenbasis of $\hat{\Sigma}^{xx}$. Panels (e) and (f) show 2 standard deviations over 10 runs.
  • Figure 4: Simulations of the equivalent JEPA and MAE ODEs for $L=2,5,6,7$ (eqn:dynj, eqn:dynm). The covariance $\lambda$ is fixed to 1 across all curves, and darker curves correspond to higher $\rho$. Only for the JEPA objective do deeper encoders induce more pronounced incremental learning of features with respect to $\rho$.
  • Figure 5: Deep linear networks trained on Gaussian data. The leftmost column shows the values of $\lambda,\rho$ used to generate the data. All networks were initialized using standard Gaussian initialization with the default scale, and the encoder depth is fixed to $L=5$. The order of feature learning is dictated by $\rho$ for the JEPA objective and by $\lambda$ for the MAE objective.
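Figures 2, 4, and 5 read off the order in which features are learned. One compact way to see a $\rho$-dependent ordering is to compute a learning time in closed form under simplifying assumptions of our own: a single feature, scalar layers, and a balanced initialization where every encoder layer and the predictor share a common value $u$. Under these assumptions (our derivation), JEPA gradient flow with the stop-gradient target reduces to $\dot u = \lambda u^{2L}(\rho - u)$ with fixed point $u^\ast = \rho$, and partial fractions give

$$t(u_0\to u_1)=\frac{1}{\lambda}\Big[\sum_{n=2}^{2L}\frac{\rho^{\,n-2L-1}}{n-1}\big(u_0^{1-n}-u_1^{1-n}\big)+\rho^{-2L}\ln\frac{u_1(\rho-u_0)}{u_0(\rho-u_1)}\Big].$$

This is an illustrative formula for the sketch, not a transcription of Theorem 4.4, and "critical time" below simply means the time for $u$ to reach half of $\rho$. The code checks the closed form against numerical quadrature; the function names are hypothetical.

```python
import numpy as np

def jepa_time_closed_form(u0, u1, lam, rho, L):
    # Time for du/dt = lam * u**(2L) * (rho - u) to go from u0 to u1, with 0 < u0 < u1 < rho.
    # Partial fractions: 1/(u^k (rho-u)) = sum_{n=1..k} rho^(n-k-1) u^(-n) + rho^(-k)/(rho-u), k = 2L.
    k = 2 * L
    t = 0.0
    for n in range(2, k + 1):
        t += rho ** (n - k - 1) * (u0 ** (1 - n) - u1 ** (1 - n)) / (n - 1)
    t += rho ** (-k) * np.log(u1 * (rho - u0) / (u0 * (rho - u1)))  # n = 1 term plus 1/(rho-u) term
    return t / lam

def jepa_time_quadrature(u0, u1, lam, rho, L, num=20_001):
    # Trapezoid rule for t = int du / (lam u^(2L) (rho - u)) on a grid uniform in s = log u.
    s = np.linspace(np.log(u0), np.log(u1), num)
    u = np.exp(s)
    f = u / (lam * u ** (2 * L) * (rho - u))   # du = u ds
    return float(np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(s)))

L, lam, u0 = 3, 1.0, 0.02
rhos = [0.5, 1.0, 2.0]
# Critical time here: u reaches half of its fixed point u* = rho.
t_crit = [jepa_time_closed_form(u0, r / 2, lam, r, L) for r in rhos]
t_quad = [jepa_time_quadrature(u0, r / 2, lam, r, L) for r in rhos]
for r, tc, tq in zip(rhos, t_crit, t_quad):
    print(f"rho={r}: closed form {tc:.6g}, quadrature {tq:.6g}")
```

With $\lambda$ fixed, the larger-$\rho$ feature crosses half of its fixed point first, and because $t$ scales as $1/\lambda$, doubling $\lambda$ halves every critical time. In the small-$u_0$ limit the leading term is $u_0^{1-2L}/((2L-1)\lambda\rho)$; the extra sensitivity to $\rho$ comes from the terms with smaller $n$, which carry higher inverse powers of $\rho$.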

Theorems & Definitions (36)

  • Theorem 4.2: ODE Equivalence
  • Corollary 4.2
  • Theorem 4.3: JEPA dynamics
  • Theorem 4.4: JEPA critical time
  • Theorem 4.5: MAE dynamics
  • Theorem 4.6: MAE critical time
  • Corollary 4.6
  • Theorem B.1: ODE Equivalence
  • proof
  • Lemma B.1
  • ...and 26 more