Gaussian Joint Embeddings For Self-Supervised Representation Learning

Yongchao Huang

Abstract

Self-supervised representation learning often relies on deterministic predictive architectures to align context and target views in latent space. While effective in many settings, such methods are limited in genuinely multi-modal inverse problems, where squared-loss prediction collapses towards conditional averages, and they frequently depend on architectural asymmetries to prevent representation collapse. In this work, we propose a probabilistic alternative based on generative joint modeling. We introduce Gaussian Joint Embeddings (GJE) and its multi-modal extension, Gaussian Mixture Joint Embeddings (GMJE), which model the joint density of context and target representations and replace black-box prediction with closed-form conditional inference under an explicit probabilistic model. This yields principled uncertainty estimates and a covariance-aware objective for controlling latent geometry. We further identify a failure mode of naive empirical batch optimization, which we term the Mahalanobis Trace Trap, and develop several remedies spanning parametric, adaptive, and non-parametric settings, including prototype-based GMJE, conditional Mixture Density Networks (GMJE-MDN), topology-adaptive Growing Neural Gas (GMJE-GNG), and a Sequential Monte Carlo (SMC) memory bank. In addition, we show that standard contrastive learning can be interpreted as a degenerate non-parametric limiting case of the GMJE framework. Experiments on synthetic multi-modal alignment tasks and vision benchmarks show that GMJE recovers complex conditional structure, learns competitive discriminative representations, and defines latent densities that are better suited to unconditional sampling than deterministic or unimodal baselines.
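
The "closed-form conditional inference" mentioned above is, in the single-Gaussian case, the standard Gaussian conditioning identity; we restate it here for orientation, in our own block-partition notation (assumed for illustration only). If the context and target embeddings are modelled as jointly Gaussian,

$$
\begin{pmatrix} \mathbf{z}_c \\ \mathbf{z}_t \end{pmatrix} \sim \mathcal{N}\!\left( \begin{pmatrix} \boldsymbol{\mu}_c \\ \boldsymbol{\mu}_t \end{pmatrix}, \begin{pmatrix} \boldsymbol{\Sigma}_{cc} & \boldsymbol{\Sigma}_{ct} \\ \boldsymbol{\Sigma}_{tc} & \boldsymbol{\Sigma}_{tt} \end{pmatrix} \right),
$$

then the conditional distribution of the target given the context is again Gaussian,

$$
p(\mathbf{z}_t \mid \mathbf{z}_c) = \mathcal{N}\!\left( \boldsymbol{\mu}_t + \boldsymbol{\Sigma}_{tc}\boldsymbol{\Sigma}_{cc}^{-1}(\mathbf{z}_c - \boldsymbol{\mu}_c),\; \boldsymbol{\Sigma}_{tt} - \boldsymbol{\Sigma}_{tc}\boldsymbol{\Sigma}_{cc}^{-1}\boldsymbol{\Sigma}_{ct} \right),
$$

so a predictive mean and an explicit predictive covariance follow directly from the joint model rather than from a separate learned predictor.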

Paper Structure

This paper contains 170 sections, 2 theorems, 158 equations, 10 figures, 2 tables, 6 algorithms.

Key Result

Theorem 1

Any sufficiently smooth probability density function $p(\mathbf{z})$ on $\mathbb{R}^d$ can be approximated arbitrarily closely in $L^1$ distance by a Gaussian Mixture Model (GMM) with a finite but sufficiently large number of components, placed so as to cover the whole support of $p$.
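
Written out in our own notation, a formal reading of this statement is: for every $\varepsilon > 0$ there exist $K \in \mathbb{N}$, mixing weights $\pi_k \geq 0$ with $\sum_{k=1}^{K} \pi_k = 1$, means $\boldsymbol{\mu}_k$, and covariances $\boldsymbol{\Sigma}_k$ such that

$$
\int_{\mathbb{R}^d} \Big| \, p(\mathbf{z}) - \sum_{k=1}^{K} \pi_k \, \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \, \Big| \, \mathrm{d}\mathbf{z} < \varepsilon .
$$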

Figures (10)

  • Figure 1: A comparison of (a) the classic JEPA framework, based on separate encoding and alignment via deterministic latent prediction, and (b) our proposed GJE framework, which models the joint distribution of context and target representations and derives predictions probabilistically. The component with a red outline in (a) represents the deterministic distance loss specific to classic JEPA. 'sg' denotes stop gradient, and $\varepsilon$ represents injected side information (e.g. physics, action conditioning or noise).
  • Figure 2: The general Gaussian Mixture Joint Embeddings (GMJE) framework. Dual encoders map the context and target views to a joint embedding $Z$. This embedding is evaluated against a set of $K$ learnable global mixture parameters ($\mu_k, \Sigma_k, \pi_k$) to natively model the joint probability density $p(z_c, z_t)$ within a symmetric architectural framework (a minimal density-evaluation sketch follows this figure list).
  • Figure 3: The Information Bottleneck in the GMJE-MDN architecture. To prevent identity collapse, the parameter network must take only the context embedding $z_c$ as input to predict the conditional mixture parameters. These generated parameters are then evaluated against the target embedding $z_t$ inside the conditional mixture model. To prevent representation collapse, this conditional objective is coupled with a marginal loss on $z_c$.
  • Figure 4: Predictive distributions on Dataset A (Separated Branches). (a) Classic JEPA collapses toward the conditional average and fails to recover the three valid branches. Ground-truth samples are shown in gray. (b) The dual-space RBF kernel baseline learns a flexible mean but retains a unimodal Gaussian predictive distribution, producing a broad uncertainty band rather than resolving separate branches. (c) GMJE-EM ($K=1$) fits a single global ellipse and heavily over-smooths the manifold. (d) GMJE-EM ($K=3$) separates the density across multiple components, but the fixed covariance ellipses remain geometrically rigid. (e) GMJE-GNG adaptively places prototypes along the non-linear ridges and captures the manifold topology more faithfully. (f) GMJE-MDN produces the closest qualitative match to the ground-truth conditional density by predicting instance-dependent mixture parameters.
  • Figure 5: Internal parameter routing of GMJE-MDN on Dataset A. (a) The conditional mixing weights remain close to the ground-truth uniform value of $1/3$. (b) The learned conditional means align closely with the three underlying branch functions. (c) The predicted conditional standard deviations remain near the true Gaussian observation noise level $\epsilon = 0.05$.
  • ...and 5 more figures
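
To make the density evaluation described in the Figure 2 caption concrete, the following PyTorch-style sketch (ours, not the paper's implementation) evaluates a concatenated joint embedding against $K$ learnable mixture components; the diagonal covariances and all class, variable, and parameter names are illustrative assumptions.

    import math
    import torch

    class GMJEDensity(torch.nn.Module):
        """Illustrative sketch: K learnable global mixture components over the
        joint embedding z = [z_c, z_t]. Diagonal covariances are assumed here
        for simplicity; the paper's Sigma_k may be full covariances."""
        def __init__(self, dim_joint: int, K: int):
            super().__init__()
            self.mu = torch.nn.Parameter(torch.randn(K, dim_joint))       # component means mu_k
            self.log_var = torch.nn.Parameter(torch.zeros(K, dim_joint))  # log diagonal variances
            self.logit_pi = torch.nn.Parameter(torch.zeros(K))            # unnormalised mixing weights pi_k

        def log_prob(self, z_c: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
            z = torch.cat([z_c, z_t], dim=-1)                 # joint embedding, shape (B, dim_joint)
            diff = z.unsqueeze(1) - self.mu                   # (B, K, dim_joint)
            log2pi = math.log(2.0 * math.pi)
            # per-component diagonal-Gaussian log density, shape (B, K)
            log_comp = -0.5 * ((diff ** 2) / self.log_var.exp() + self.log_var + log2pi).sum(-1)
            log_pi = torch.log_softmax(self.logit_pi, dim=0)  # normalised log mixing weights
            return torch.logsumexp(log_pi + log_comp, dim=1)  # log p(z_c, z_t), shape (B,)

    # Usage: minimise the negative joint log-likelihood of the encoder outputs.
    # density = GMJEDensity(dim_joint=2 * 64, K=8)
    # loss = -density.log_prob(z_c, z_t).mean()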

Theorems & Definitions (4)

  • Theorem 1: Universal approximation property of GMMs [li_mixture_1999, Huang2025GMA]
  • Proof of Theorem 1
  • Corollary 1: Universal approximation with isotropic Gaussian mixtures [mclachlan1988mixture, bishop1994mixture]
  • Proof of Corollary 1