BiJEPA: Bi-directional Joint Embedding Predictive Architecture for Symmetric Representation Learning

Yongchao Huang

TL;DR

This work proposes BiJEPA, a bi-directional JEPA that enforces cycle-consistent predictability between data segments, and addresses the inherent instability of symmetric prediction (representation explosion) by introducing a critical norm regularization mechanism on the representation vectors.

Abstract

Self-Supervised Learning (SSL) has shifted from pixel-level reconstruction to latent space prediction, spearheaded by the Joint Embedding Predictive Architecture (JEPA). While effective, standard JEPA models typically rely on a uni-directional prediction mechanism (e.g., Context $\to$ Target), potentially neglecting the informative signal inherent in the inverse relationship and degrading performance. In this work, we propose BiJEPA, a Bi-Directional Joint Embedding Predictive Architecture that enforces cycle-consistent predictability between data segments. We address the inherent instability of symmetric prediction (representation explosion) by introducing a critical norm regularization mechanism on the representation vectors. We evaluate BiJEPA on three distinct modalities: synthetic periodic signals, chaotic Lorenz attractor trajectories, and high-dimensional image data (MNIST). Our results demonstrate that BiJEPA achieves stable convergence without collapse, captures the semantic structure of chaotic systems, and learns robust temporal and spatial representations capable of generation and generalisation, offering a more holistic approach to representation learning.
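
To make the symmetric objective described in the abstract concrete, the following is a minimal NumPy sketch of a forward-only loss computation. It is an illustration under stated assumptions, not the paper's implementation: the toy one-layer encoder, the linear predictors, the names `encode`, `bijepa_loss`, `lam`, and `target_norm`, and in particular the form of the norm penalty (squared deviation of each representation norm from a fixed target value) are all hypothetical. The abstract states only that a norm regularization on the representation vectors is used, not its exact form.

```python
import numpy as np


def encode(W, x):
    """Toy encoder (assumption): one linear layer followed by tanh."""
    return np.tanh(x @ W)


def bijepa_loss(params, target_W, x, y, lam=0.1, target_norm=1.0):
    """Forward-only sketch of a bi-directional latent prediction loss.

    params   : dict with online encoder weights and two predictor matrices
    target_W : target encoder weights, treated as constants here; in an
               autodiff framework a stop-gradient would be applied instead
    lam, target_norm : hypothetical regularization weight and target norm
    """
    # Online representations of both segments (shared online encoder).
    z_x = encode(params["W_enc"], x)
    z_y = encode(params["W_enc"], y)

    # Target representations of both segments (shared target encoder).
    t_x = encode(target_W, x)
    t_y = encode(target_W, y)

    # Forward direction: predict the target embedding of y from x.
    fwd = np.mean(np.sum((z_x @ params["P_fwd"] - t_y) ** 2, axis=1))
    # Backward direction: predict the target embedding of x from y.
    bwd = np.mean(np.sum((z_y @ params["P_bwd"] - t_x) ** 2, axis=1))

    # Assumed norm penalty: keep representation norms near target_norm so
    # that they neither grow without bound nor shrink to zero.
    norms = np.linalg.norm(np.concatenate([z_x, z_y]), axis=1)
    norm_reg = np.mean((norms - target_norm) ** 2)

    total = fwd + bwd + lam * norm_reg
    return total, {"fwd": fwd, "bwd": bwd, "norm_reg": norm_reg}


# Usage on random data (sizes are arbitrary illustration values).
rng = np.random.default_rng(0)
d_in, d_z, batch = 8, 4, 16
params = {
    "W_enc": rng.normal(scale=0.5, size=(d_in, d_z)),
    "P_fwd": rng.normal(scale=0.5, size=(d_z, d_z)),
    "P_bwd": rng.normal(scale=0.5, size=(d_z, d_z)),
}
target_W = params["W_enc"].copy()  # target starts as a copy of the online encoder
x = rng.normal(size=(batch, d_in))
y = rng.normal(size=(batch, d_in))
total, parts = bijepa_loss(params, target_W, x, y)
```

Setting `lam=0.0` in this sketch leaves only the two prediction terms, which gives a simple way to compare the regularized and unregularized objectives on the same batch.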

Paper Structure (62 sections, 13 equations, 6 figures, 3 tables)

This paper contains 62 sections, 13 equations, 6 figures, 3 tables.

Figures (6)

  • Figure 1: JEPA vs. BiJEPA. (A) Standard JEPA learns a uni-directional mapping ($x \to y$). (B) BiJEPA adds a backward predictor $P_{bwd}$ and enforces consistency in both directions (symmetric predictability), learning to map $x \to y$ and $y \to x$ simultaneously using Online ($f_\theta$) and Target ($f_{\bar{\theta}}$) encoders to prevent collapse. SG denotes the Stop-Gradient operation. NB: both the forward and backward loops share the same Online encoder $f_{\theta}$ and Target encoder $f_{\bar{\theta}}$.
  • Figure 2: Impact of stability constraints & architecture. Each row displays: (Left) Training loss; (Middle) Batch forecasting accuracy (Sample Index vs Amplitude); (Right) Single-sample trajectory forecast. (a) Unconstrained BiJEPA diverges due to representation explosion. (b) Expressive BiJEPA achieves stable convergence and high accuracy. (c) Classic JEPA is stable but exhibits noisier loss dynamics and significantly higher forecasting error.
  • Figure 3: Forecasting chaotic dynamics. (Left) Training loss curves showing BiJEPA's superior stability. (Middle) 1-step forecast accuracy for the X-coordinate across a batch of 20 random test samples; BiJEPA (green) tracks the truth (grey) better than Classic JEPA (red). (Right) 3D phase space reconstruction of a single sample trajectory.
  • Figure 4: Generative hallucinations. The models are given only the Left Half (Input) and must generate the Right Half. BiJEPA (b) produces sharper and more semantically consistent completions than the baseline (a), particularly for difficult digits like '2' and '4'.
  • Figure 5: Dataset samples. Three random sine waves from the validation set showing variations in frequency $\omega \sim \mathcal{U}(0.8, 1.2)$ and phase $\phi \sim \mathcal{U}(0, 2\pi)$.
  • ...and 1 more figure
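
The sampling ranges given in the Figure 5 caption can be sketched as a small generator for synthetic sine waves. Only the two uniform ranges for $\omega$ and $\phi$ come from the caption; the signal form $\sin(\omega t + \phi)$, the number of time steps, the time horizon, and the name `sample_sine_wave` are assumptions made for illustration.

```python
import math
import random


def sample_sine_wave(rng, n_steps=100, t_max=10.0):
    """Draw one sine wave with random frequency and phase.

    omega ~ U(0.8, 1.2) and phi ~ U(0, 2*pi) follow the Figure 5 caption;
    n_steps and t_max are hypothetical values chosen for this sketch.
    """
    omega = rng.uniform(0.8, 1.2)
    phi = rng.uniform(0.0, 2.0 * math.pi)
    dt = t_max / (n_steps - 1)
    values = [math.sin(omega * i * dt + phi) for i in range(n_steps)]
    return omega, phi, values


# Three random waves, mirroring the three samples shown in the figure.
rng = random.Random(0)
waves = [sample_sine_wave(rng) for _ in range(3)]
```

Each wave could then be split into two segments (for example, a first and second half) to form the $x$ and $y$ inputs of the two prediction directions; how the paper actually segments the signal is not stated in this excerpt.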