Why and How Auxiliary Tasks Improve JEPA Representations

Jiacan Yu, Siyi Chen, Mingrui Liu, Nono Horiuchi, Vladimir Braverman, Zicheng Xu, Dan Haramati, Randall Balestriero

TL;DR

This work analyzes a Joint-Embedding Predictive Architecture (JEPA) augmented with an auxiliary regression head that is trained jointly with the latent dynamics. It introduces a simple, practical model (P-JEPA) with a joint loss $\mathcal{L}=\mathcal{L}_{dyn}+c_p\mathcal{L}_p$ and proves a No Unhealthy Representation Collapse theorem in deterministic MDPs: when both losses vanish, non-bisimilar observations must map to distinct latent representations. In a counting environment, the authors demonstrate that using the reward as the auxiliary target yields nine distinct latent clusters corresponding to object counts 0–8, while the other ablations produce weaker structure or collapse, supporting the theory. The findings offer a practical guideline: choose auxiliary objectives that encode the phenomenon of interest, so that they enforce the right equivalence relations and improve JEPA encoders, with implications for model-based RL methods like TD-MPC2.

Abstract

Joint-Embedding Predictive Architecture (JEPA) is increasingly used for visual representation learning and as a component in model-based RL, but its behavior remains poorly understood. We provide a theoretical characterization of a simple, practical JEPA variant that has an auxiliary regression head trained jointly with latent dynamics. We prove a No Unhealthy Representation Collapse theorem: in deterministic MDPs, if training drives both the latent-transition consistency loss and the auxiliary regression loss to zero, then any pair of non-equivalent observations, i.e., those that do not have the same transition dynamics or auxiliary value, must map to distinct latent representations. Thus, the auxiliary task anchors which distinctions the representation must preserve. Controlled ablations in a counting environment corroborate the theory and show that training the JEPA model jointly with the auxiliary head generates a richer representation than training them separately. Our work indicates a path to improve JEPA encoders: training them with an auxiliary function that, together with the transition dynamics, encodes the right equivalence relations.
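The two loss terms named in the abstract can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: it assumes squared-error forms for both the latent-transition consistency loss and the auxiliary regression loss, a scalar auxiliary target, and plain Python callables standing in for $E_\phi$, $T_\psi$, and $P_\theta$. The names `p_jepa_loss`, `squared_error`, and the toy batch are hypothetical. A real model would compute the same quantity in an autodiff framework.

```python
def squared_error(u, v):
    """Squared Euclidean distance between two latent vectors (tuples of floats)."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def p_jepa_loss(encoder, transition, head, batch, c_p=1.0):
    """Return (total, dynamics, auxiliary) losses averaged over the batch.

    batch is a list of (o, a, o_next, p_target) tuples.
    Assumption: both terms are mean squared errors; the paper may use other forms.
    """
    dyn, aux = 0.0, 0.0
    for o, a, o_next, p_target in batch:
        z, z_next = encoder(o), encoder(o_next)
        dyn += squared_error(transition(z, a), z_next)  # latent-transition consistency
        aux += (head(z) - p_target) ** 2                # auxiliary regression
    n = len(batch)
    return dyn / n + c_p * aux / n, dyn / n, aux / n


# Hypothetical toy data: the observation is an object count, the action adds or
# removes one object, and the auxiliary target p(o) equals the count.
batch = [(0, +1, 1, 0.0), (1, +1, 2, 1.0), (2, -1, 1, 2.0)]

# A collapsed encoder maps every observation to the same latent.
collapsed = p_jepa_loss(
    encoder=lambda o: (0.0,),
    transition=lambda z, a: z,
    head=lambda z: 0.0,
    batch=batch,
)

# An encoder that keeps the count can satisfy both terms.
faithful = p_jepa_loss(
    encoder=lambda o: (float(o),),
    transition=lambda z, a: (z[0] + a,),
    head=lambda z: z[0],
    batch=batch,
)

print("collapsed (total, dyn, aux):", collapsed)
print("faithful  (total, dyn, aux):", faithful)
```

In this toy setup the collapsed encoder drives the dynamics term to zero while the auxiliary term stays positive, which is the sense in which the auxiliary task anchors the distinctions the representation must keep.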

Paper Structure

This paper contains 18 sections, 6 theorems, 17 equations, 2 figures.

Key Result

Theorem 1

Let $\mathcal{M}$ be a deterministic MDP, $p$ be a function of observations, and a P-JEPA model be well-trained: $T_\psi(E_\phi(o),a)=E_\phi\bigl(f(o,a)\bigr)$ and $P_\theta(E_\phi(o)) = p(o)$ for all $o$ and $a$. Then any pair of observations that is not in the largest bisimulation over $\mathcal{M}$ [...]
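The reasoning behind a result of this kind can be sketched briefly. The sketch assumes the standard notion of bisimulation for a deterministic MDP with $p$ as the labeling function: a relation $R$ is a bisimulation if related observations share the same value of $p$ and have related successors under every action. The paper's own definitions and proof may differ in detail.

```latex
\begin{align*}
&\text{Let } R = \{(o, o') : E_\phi(o) = E_\phi(o')\}.
  \text{ For any } (o, o') \in R \text{ and any action } a: \\
&\quad p(o) = P_\theta(E_\phi(o)) = P_\theta(E_\phi(o')) = p(o'), \\
&\quad E_\phi\bigl(f(o,a)\bigr) = T_\psi(E_\phi(o), a) = T_\psi(E_\phi(o'), a)
  = E_\phi\bigl(f(o',a)\bigr)
  \;\Longrightarrow\; \bigl(f(o,a), f(o',a)\bigr) \in R.
\end{align*}
```

Under these assumptions $R$ is itself a bisimulation, so it is contained in the largest one. By contraposition, two observations that are not related by the largest bisimulation cannot share a latent representation.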

Figures (2)

  • Figure 1: Architecture of P-JEPA. The pentagon is the JEPA core: $E_\phi$ is the encoder; $T_\psi$ is the latent transition model. $P_\theta(z_t)$ regresses to an auxiliary function of observations $p$. $p$ can be the reward $r$ or a randomly initialized neural network; see Sec. \ref{sec:knowledge-and-mbrl}. $E_\phi$ is updated by both the dynamics loss and the auxiliary loss; no target/EMA encoder \cite{NEURIPS2020_f3ada80d} is used.
  • Figure 2: Each row has three panels. Left: PCA of embeddings of 256 randomly chosen observations; different colors correspond to different object counts. Middle: pairwise $\ell_2$ distances between the same 256 embeddings, with samples sorted by object count and red grid lines marking count boundaries. Right: example observations (left of each pair) and decoder outputs (right, normalized for better contrast). Top row, P-JEPA with reward auxiliary: PCA shows nine clear clusters; the diagonal blocks in the heatmap are darker than the off-diagonal blocks, indicating separation according to object count; reconstructions discard shape/color/position. Middle row, P-JEPA with 256-dimensional random auxiliary: no count structure; in the heatmap, the diagonal blocks are as bright as the off-diagonal blocks, showing that distances within the same object count are comparable to those across different counts; the decoder recovers position and partial color/shape information. Bottom row, encoder receiving gradients only from the reward loss: the representation space shows only coarse separation; the heatmap exhibits only coarse block structure, roughly grouping counts into three sets (0–2, 3–5, 6–8); the decoder cannot recover color/shape/position information.
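The middle-panel diagnostic in Figure 2 (pairwise $\ell_2$ distances with samples sorted by object count) can be reproduced on any set of embeddings. The sketch below is an illustration only: the embeddings are synthetic stand-ins generated around their count, not outputs of a trained P-JEPA encoder, and the helper names `sorted_distance_matrix` and `block_means` are hypothetical.

```python
import math
import random


def sorted_distance_matrix(embeddings, counts):
    """Sort samples by object count and return (distance matrix, sorted counts)."""
    order = sorted(range(len(counts)), key=lambda i: counts[i])
    z = [embeddings[i] for i in order]
    sorted_counts = [counts[i] for i in order]
    dist = [[math.dist(a, b) for b in z] for a in z]
    return dist, sorted_counts


def block_means(dist, sorted_counts):
    """Mean distance inside same-count blocks versus across different counts."""
    within, across = [], []
    n = len(sorted_counts)
    for i in range(n):
        for j in range(i + 1, n):
            same = sorted_counts[i] == sorted_counts[j]
            (within if same else across).append(dist[i][j])
    return sum(within) / len(within), sum(across) / len(across)


# Synthetic stand-in: 256 two-dimensional embeddings clustered by count (0 to 8).
rng = random.Random(0)
counts = [rng.randrange(9) for _ in range(256)]
embeddings = [(c + rng.gauss(0, 0.05), rng.gauss(0, 0.05)) for c in counts]

dist, sorted_counts = sorted_distance_matrix(embeddings, counts)
within, across = block_means(dist, sorted_counts)
print(f"mean within-count distance: {within:.3f}")
print(f"mean across-count distance: {across:.3f}")
```

A within-count mean well below the across-count mean corresponds to the dark diagonal blocks described for the top row; comparable values correspond to the middle row, where no count structure is visible.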

Theorems & Definitions (14)

  • Definition 1: Largest bisimulation
  • Theorem 1: No Unhealthy Representation Collapse
  • Definition 2: Bisimulation for deterministic MDP
  • Definition 3: Largest bisimulation
  • Theorem 2: No Unhealthy Representation Collapse
  • proof
  • Definition 4: Empirical largest bisimulation
  • Theorem 3: Empirical No Unhealthy Representation Collapse
  • proof
  • Proposition 1: $9$-way partition is a bisimulation
  • ...and 4 more
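For a finite deterministic MDP, the largest bisimulation referred to in these definitions can be computed by partition refinement. The sketch below assumes the standard definition (related observations share the same auxiliary value $p$ and have related successors under every action). The toy environment is hypothetical and much simpler than the paper's counting environment: an observation is a pair of an object count and a nuisance bit, and actions add or remove one object.

```python
def largest_bisimulation(observations, actions, step, p):
    """Return a dict mapping each observation to the id of its bisimulation class.

    Starts from the partition induced by the auxiliary value p(o) and repeatedly
    splits blocks whose members have successors in different blocks.
    """
    block = {o: p(o) for o in observations}
    while True:
        signature = {
            o: (block[o], tuple(block[step(o, a)] for a in actions))
            for o in observations
        }
        ids = {}
        new_block = {o: ids.setdefault(signature[o], len(ids)) for o in observations}
        # Refinement only splits blocks, so an unchanged block count means stability.
        if len(set(new_block.values())) == len(set(block.values())):
            return new_block
        block = new_block


# Hypothetical toy MDP: observation = (count, nuisance bit), count in 0..8.
observations = [(c, n) for c in range(9) for n in range(2)]
actions = (+1, -1)


def step(o, a):
    """Deterministic transition: change the count (clipped to 0..8), flip the bit."""
    c, n = o
    return (min(8, max(0, c + a)), 1 - n)


# Auxiliary target equal to the count: classes follow the count, nuisance is ignored.
by_count = largest_bisimulation(observations, actions, step, p=lambda o: o[0])

# Constant auxiliary target: every observation is equivalent, so collapse is allowed.
by_const = largest_bisimulation(observations, actions, step, p=lambda o: 0)

print("classes with p = count:   ", len(set(by_count.values())))
print("classes with p = constant:", len(set(by_const.values())))
```

In this toy setting the choice of $p$ decides which distinctions survive: a count-valued target separates the counts while merging the nuisance bit, and a constant target permits a single class.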