Why and How Auxiliary Tasks Improve JEPA Representations
Jiacan Yu, Siyi Chen, Mingrui Liu, Nono Horiuchi, Vladimir Braverman, Zicheng Xu, Dan Haramati, Randall Balestriero
TL;DR
This work analyzes a Joint-Embedding Predictive Architecture (JEPA) augmented with an auxiliary regression head trained jointly with latent dynamics. It introduces a simple, practical model (P-JEPA) with a joint loss $\mathcal{L}=\mathcal{L}_{dyn}+c_p\mathcal{L}_p$ and proves a No Unhealthy Representation Collapse theorem for deterministic MDPs: when both losses vanish, non-bisimilar observations must map to distinct latent representations. In a counting environment, the authors show that using the reward as the auxiliary target yields nine distinct latent clusters corresponding to object counts 0–8, while other ablations produce weaker structure or collapse, which supports the theory. The findings offer a practical guideline: choose auxiliary objectives that encode the phenomenon of interest, so that they enforce the right equivalence relations and improve JEPA encoders, with implications for model-based RL methods such as TD-MPC2.
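To make the joint objective $\mathcal{L}=\mathcal{L}_{dyn}+c_p\mathcal{L}_p$ concrete, here is a minimal pure-Python sketch of how the two terms could be computed on a batch of deterministic transitions. It is not the authors' implementation. The names `p_jepa_loss_terms`, `p_jepa_loss`, the squared-error form of both terms, and the choice that the auxiliary head reads only the latent $z$ (rather than $z$ and the action) are all assumptions for illustration. The hypothetical toy data uses the object count as both observation and auxiliary target.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# (observation, action, next observation, auxiliary target)
Transition = Tuple[Vector, int, Vector, float]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    assert len(a) == len(b)
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def p_jepa_loss_terms(
    encode: Callable[[Vector], Vector],
    predict: Callable[[Vector, int], Vector],
    aux_head: Callable[[Vector], float],
    batch: Sequence[Transition],
) -> Tuple[float, float]:
    """Return (L_dyn, L_p) averaged over the batch (assumed squared-error form)."""
    l_dyn = 0.0
    l_p = 0.0
    for obs, action, next_obs, target in batch:
        z = encode(obs)
        # Latent-transition consistency: predicted next latent vs encoded next obs.
        # In practice the target branch is commonly a stop-gradient or EMA encoder;
        # that detail is not modeled here.
        l_dyn += mse(predict(z, action), encode(next_obs))
        # Auxiliary regression from the latent to a scalar target (e.g. reward).
        l_p += (aux_head(z) - target) ** 2
    n = len(batch)
    return l_dyn / n, l_p / n


def p_jepa_loss(encode, predict, aux_head, batch, c_p: float = 1.0) -> float:
    """Joint loss L = L_dyn + c_p * L_p."""
    l_dyn, l_p = p_jepa_loss_terms(encode, predict, aux_head, batch)
    return l_dyn + c_p * l_p


# Hypothetical toy data: observation = [count], action in {-1, 0, +1}, target = count.
DATA: List[Transition] = [
    ([float(k)], a, [float(k + a)], float(k)) for k in range(1, 8) for a in (-1, 0, 1)
]


def encode_identity(obs: Vector) -> Vector:
    return list(obs)


def encode_collapsed(obs: Vector) -> Vector:
    return [0.0]  # every observation mapped to the same latent


def predict_shift(z: Vector, action: int) -> Vector:
    return [z[0] + action]


def predict_stay(z: Vector, action: int) -> Vector:
    return list(z)


def read_count(z: Vector) -> float:
    return z[0]
```

In this toy, the collapsed encoder drives $\mathcal{L}_{dyn}$ to zero on its own, but the auxiliary term stays positive; only an encoder that keeps the counts apart makes both terms vanish. This mirrors, in a very small setting, why the theorem needs both losses to be zero.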
Abstract
Joint-Embedding Predictive Architecture (JEPA) is increasingly used for visual representation learning and as a component in model-based RL, but its behavior remains poorly understood. We provide a theoretical characterization of a simple, practical JEPA variant with an auxiliary regression head trained jointly with latent dynamics. We prove a No Unhealthy Representation Collapse theorem: in deterministic MDPs, if training drives both the latent-transition consistency loss and the auxiliary regression loss to zero, then any pair of non-equivalent observations, i.e., those that differ in their transition dynamics or in their auxiliary value, must map to distinct latent representations. Thus, the auxiliary task anchors which distinctions the representation must preserve. Controlled ablations in a counting environment corroborate the theory and show that training the JEPA model jointly with the auxiliary head yields a richer representation than training them separately. Our work indicates a path to improving JEPA encoders: train them with an auxiliary function that, together with the transition dynamics, encodes the right equivalence relations.
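The equivalence relation in the theorem can be illustrated by computing it directly on a small deterministic MDP. The sketch below uses standard partition refinement: start by grouping states with equal auxiliary value, then repeatedly split groups whose successors under some action land in different groups, until nothing changes. The environment (a hypothetical count of 0 to 8 objects plus an irrelevant light switch), the function name `bisimulation_classes`, and its parameters are assumptions for illustration, not the paper's environment or code.

```python
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

State = Tuple[int, int]  # (object count, light on/off); hypothetical toy state

MAX_COUNT = 8
STATES: List[State] = [(c, light) for c in range(MAX_COUNT + 1) for light in (0, 1)]
ACTIONS = ("add", "remove", "toggle_light")


def step(state: State, action: str) -> State:
    """Deterministic toy transition: add/remove an object or toggle the light."""
    count, light = state
    if action == "add":
        return (min(count + 1, MAX_COUNT), light)
    if action == "remove":
        return (max(count - 1, 0), light)
    return (count, 1 - light)


def bisimulation_classes(
    states: Sequence[State],
    actions: Sequence[str],
    step: Callable[[State, str], State],
    aux: Callable[[State], Hashable],
) -> Dict[State, int]:
    """Coarsest partition consistent with aux values and deterministic transitions."""
    # Initial partition: states are grouped by their auxiliary value.
    block: Dict[State, Hashable] = {s: aux(s) for s in states}
    while True:
        # Signature = own block + blocks reached under each action.
        ids: Dict[tuple, int] = {}
        new_block: Dict[State, int] = {}
        for s in states:
            sig = (block[s], tuple(block[step(s, a)] for a in actions))
            if sig not in ids:
                ids[sig] = len(ids)
            new_block[s] = ids[sig]
        # Each pass only refines; equal block counts means the partition is stable.
        if len(ids) == len(set(block.values())):
            return new_block
        block = new_block
```

With the count as auxiliary value, the light is irrelevant and the 18 toy states fall into 9 classes, one per count; with a constant auxiliary value, every state is equivalent, so a fully collapsed encoder would not be ruled out. This is one way to read "the auxiliary task anchors which distinctions the representation must preserve".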
