Var-JEPA: A Variational Formulation of the Joint-Embedding Predictive Architecture -- Bridging Predictive and Generative Self-Supervised Learning

Moritz Gögl, Christopher Yau

Abstract

The Joint-Embedding Predictive Architecture (JEPA) is often seen as a non-generative alternative to likelihood-based self-supervised learning, emphasizing prediction in representation space rather than reconstruction in observation space. We argue that the resulting separation from probabilistic generative modeling is largely rhetorical rather than structural. The canonical JEPA design (coupled encoders with a context-to-target predictor) mirrors the variational posteriors and learned conditional priors obtained when variational inference is applied to a particular class of coupled latent-variable models, and standard JEPA can be viewed as a deterministic specialization in which regularization is imposed through architectural and training heuristics rather than an explicit likelihood. Building on this view, we derive the Variational JEPA (Var-JEPA), which makes the latent generative structure explicit by optimizing a single Evidence Lower Bound (ELBO). This yields meaningful representations without ad-hoc anti-collapse regularizers and enables principled uncertainty quantification in the latent space. We instantiate the framework for tabular data (Var-T-JEPA) and achieve strong representation-learning and downstream performance, consistently improving over T-JEPA while remaining competitive with strong raw-feature baselines.
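The abstract's central correspondence (encoders as variational posteriors, the predictor as a learned conditional prior, and the JEPA prediction loss as a deterministic limit of a KL term) can be illustrated with a toy linear-Gaussian model. The sketch below is a simplification under stated assumptions, not the authors' Var-JEPA. It uses a single context latent $s_x$ and target latent $s_y$, linear encoders, predictor, and decoders, input-independent log-variances, and hypothetical parameter names (`W_cx`, `A`, `D_x`, and so on). It estimates the ELBO of $p(s_x)\,p_\theta(s_y \mid s_x)\,p(x \mid s_x)\,p(y \mid s_y)$,

$\mathcal{L} = \mathbb{E}_q[\log p(x \mid s_x) + \log p(y \mid s_y)] - \mathrm{KL}(q(s_x \mid x)\,\|\,p(s_x)) - \mathbb{E}_{q(s_x \mid x)}\,\mathrm{KL}(q(s_y \mid y)\,\|\,p_\theta(s_y \mid s_x))$,

with one reparameterized sample. It omits the masking, auxiliary latents, and token-based architecture of the actual model.

```python
import numpy as np


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, diag e^logvar_q) || N(mu_p, diag e^logvar_p)), summed over the last axis."""
    return 0.5 * np.sum(
        logvar_p - logvar_q
        + (np.exp(logvar_q) + (mu_q - mu_p) ** 2) / np.exp(logvar_p)
        - 1.0,
        axis=-1,
    )


def gaussian_loglik(x, mean, logvar=0.0):
    """log N(x | mean, diag e^logvar), summed over the last axis."""
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + logvar + (x - mean) ** 2 / np.exp(logvar), axis=-1
    )


def toy_elbo(x, y, params, rng):
    """Single-sample ELBO estimate for the toy model p(s_x) p(s_y|s_x) p(x|s_x) p(y|s_y).

    Hypothetical linear-Gaussian parameterization; not the paper's architecture.
    """
    # Variational posteriors: context encoder q(s_x|x) and target encoder q(s_y|y).
    mu_sx, lv_sx = x @ params["W_cx"], params["lv_cx"]
    mu_sy, lv_sy = y @ params["W_ty"], params["lv_ty"]
    # Reparameterized samples s = mu + sigma * eps.
    s_x = mu_sx + np.exp(0.5 * lv_sx) * rng.standard_normal(mu_sx.shape)
    s_y = mu_sy + np.exp(0.5 * lv_sy) * rng.standard_normal(mu_sy.shape)
    # The "predictor" parameterizes the learned conditional prior p(s_y | s_x).
    mu_pred, lv_pred = s_x @ params["A"], params["lv_pred"]
    # Reconstruction terms from unit-variance Gaussian decoders.
    rec = gaussian_loglik(x, s_x @ params["D_x"]) + gaussian_loglik(y, s_y @ params["D_y"])
    # Context latent against a standard normal prior p(s_x) = N(0, I).
    kl_x = gaussian_kl(mu_sx, lv_sx, 0.0, 0.0)
    # Target posterior against the predictor's prior; this KL takes the place
    # of JEPA's prediction loss (single-sample estimate of the expectation over s_x).
    kl_y = gaussian_kl(mu_sy, lv_sy, mu_pred, lv_pred)
    return np.mean(rec - kl_x - kl_y)


# Hypothetical toy parameters and data.
rng = np.random.default_rng(0)
n, dx, dy, d = 8, 5, 3, 2
params = {
    "W_cx": 0.1 * rng.standard_normal((dx, d)), "lv_cx": np.zeros(d),
    "W_ty": 0.1 * rng.standard_normal((dy, d)), "lv_ty": np.zeros(d),
    "A": np.eye(d), "lv_pred": np.zeros(d),
    "D_x": 0.1 * rng.standard_normal((d, dx)), "D_y": 0.1 * rng.standard_normal((d, dy)),
}
x, y = rng.standard_normal((n, dx)), rng.standard_normal((n, dy))
elbo = toy_elbo(x, y, params, rng)

# Deterministic limit: with variances held fixed, moving the predicted mean changes
# the KL by exactly ||mu_q - mu_p||^2 / (2 sigma_p^2), a scaled L2 prediction loss.
mu_t = np.array([[1.0, -2.0]])
lv_q, lv_p = np.full(2, np.log(1e-4)), np.zeros(2)  # near-deterministic posterior, sigma_p^2 = 1
pred_a, pred_b = np.array([[0.5, -1.0]]), mu_t.copy()
diff_kl = gaussian_kl(mu_t, lv_q, pred_a, lv_p) - gaussian_kl(mu_t, lv_q, pred_b, lv_p)
l2_term = 0.5 * np.sum((mu_t - pred_a) ** 2, axis=-1)
```

With the variances held fixed, the last KL term depends on the means only through $\|\mu_q - \mu_p\|^2 / (2\sigma_p^2)$, which is one way to read the claim that a deterministic predictor trained with an L2 loss is a special case. In this toy setup it is the reconstruction terms that tie the latents to the observations.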

Paper Structure (88 sections, 45 equations, 6 figures, 7 tables, 1 algorithm)

Figures (6)

  • Figure 1: Comparison between JEPA (left) and Var-JEPA (right). Var-JEPA extends JEPA by replacing the deterministic encoders and predictor with probabilistic (distributional) counterparts and by adding generative networks (decoders), enabling variational learning under a unified ELBO.
  • Figure 2: Epoch-wise distribution diagnostics for selected experiments. We show how aggregated KL divergences, SIGReg-MSE, isotropy metrics, and conditional-prior coupling evolve during training.
  • Figure 3: Uncertainty quantification on the MNIST (A) and SIM (B) datasets. Left: risk–coverage curve induced by abstaining on samples with highest latent uncertainty. Middle: standardized latent uncertainty versus simulated uncertainty. Right: ROC curve for detecting high-ambiguity samples from latent uncertainty.
  • Figure 4: Architectural overview of Var-T-JEPA for tabular data. The model implements the theoretical framework through specialized components: (1) the context encoder processes masked tabular features to produce variational context representations $q_\phi(s_x^{(j)}|x)$; (2) the auxiliary encoder infers predictive latents $q_\phi(z|s_x)$; (3) the target posterior prepends context and auxiliary latents as special tokens before feature processing $q_\phi(s_{y_k}^{(j)}|s_x,z,w;m_{\text{trg}}^{(k)})$; (4) the predictor generates target representations $p_\theta(s_{y_k}^{(j)}|s_x,z;m_{\text{trg}}^{(k)})$; and (5) the decoders reconstruct original tabular features from latent representations.
  • Figure 5: Visualization of MNIST image corruption under different values of the uncertainty score $u_i$.
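The risk-coverage and ROC analyses described for Figure 3 need only a per-sample latent uncertainty score. The sketch below assumes that score is the mean posterior variance over latent dimensions. The helper `latent_uncertainty` and its aggregation are hypothetical, and the paper's exact score, its standardization, and its evaluation code may differ. Risk is taken here to be the downstream error rate on the retained samples.

```python
import numpy as np


def latent_uncertainty(logvar):
    """Hypothetical per-sample score: mean posterior variance over latent dims."""
    return np.exp(logvar).mean(axis=-1)


def risk_coverage(uncertainty, errors):
    """Risk-coverage curve obtained by abstaining on the most uncertain samples first.

    Returns (coverage, risk): coverage[k-1] = k / n and risk[k-1] is the error
    rate on the k least uncertain samples.
    """
    order = np.argsort(uncertainty, kind="stable")  # most certain first
    err_sorted = np.asarray(errors, dtype=float)[order]
    k = np.arange(1, len(err_sorted) + 1)
    return k / len(err_sorted), np.cumsum(err_sorted) / k


def auroc(scores, labels):
    """ROC AUC via the Mann-Whitney U statistic (ties count as 1/2)."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    pos, neg = scores[labels], scores[~labels]
    greater = np.sum(pos[:, None] > neg[None, :])
    ties = np.sum(pos[:, None] == neg[None, :])
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


# Toy example: five samples with per-dimension posterior log-variances.
logvar = np.log(np.array([[0.1, 0.1], [0.9, 0.9], [0.3, 0.3], [0.7, 0.7], [0.2, 0.2]]))
errors = np.array([0, 1, 0, 1, 0])                      # 1 = downstream prediction was wrong
ambiguous = np.array([False, True, False, True, False])  # ground-truth high-ambiguity flags
u = latent_uncertainty(logvar)
coverage, risk = risk_coverage(u, errors)
# coverage == [0.2, 0.4, 0.6, 0.8, 1.0]; risk == [0.0, 0.0, 0.0, 0.25, 0.4]
auc = auroc(u, ambiguous)  # 1.0: every ambiguous sample gets a higher uncertainty
```

Sorting once and taking cumulative error rates yields the whole curve in $O(n \log n)$; the pairwise AUROC is $O(n_+ n_-)$ and is meant only for small illustrations.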