JEPA for RL: Investigating Joint-Embedding Predictive Architectures for Reinforcement Learning
Tristan Kenneweg, Philip Kenneweg, Barbara Hammer
TL;DR
Problem: reinforcement learning (RL) from image observations is slow because the inputs are high-dimensional. Approach: adapt the Joint-Embedding Predictive Architecture (JEPA) to RL, using a vision transformer to learn compact, predictive embeddings. Contributions: (1) a concrete JEPA-RL pipeline with an x-encoder that processes the most recent frames and a momentum-updated y-encoder; (2) collapse-mitigation strategies, including propagating the gradients of the actor and critic losses into the encoder and adding a variance regularizer; and (3) an empirical evaluation on Cart Pole showing strong gains when JEPA training is combined with RL updates, and stable training when the regularizer is used. Significance: demonstrates that JEPA is a promising direction for self-supervised representation learning in RL from pixel inputs.
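The momentum-updated y-encoder and the predictive objective can be sketched as follows. This is a minimal illustration, not the authors' implementation: parameters are flattened into plain Python lists, the names `ema_update` and `jepa_prediction_loss` are hypothetical, and both the default momentum of 0.99 and the mean-squared-error prediction loss are assumptions rather than details confirmed by the source.

```python
from typing import List

Vector = List[float]


def ema_update(target_params: Vector, online_params: Vector, momentum: float = 0.99) -> Vector:
    """Momentum (exponential moving average) update of the y-encoder.

    Assumption: the y-encoder's parameters slowly track the x-encoder's
    parameters instead of receiving gradients directly.
    """
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("momentum must be in [0, 1]")
    if len(target_params) != len(online_params):
        raise ValueError("parameter vectors must have the same length")
    # new_target = m * target + (1 - m) * online, element by element
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target_params, online_params)]


def jepa_prediction_loss(predicted: Vector, target: Vector) -> float:
    """Mean squared error between the predicted embedding and the y-encoder embedding.

    In a real training loop the target would be treated as a constant
    (no gradient flows into the y-encoder through this loss).
    """
    if len(predicted) != len(target) or not target:
        raise ValueError("embeddings must be non-empty and of equal length")
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)
```

For example, `ema_update([0.0, 1.0], [1.0, 1.0], momentum=0.5)` returns `[0.5, 1.0]`. A momentum close to 1 makes the y-encoder change slowly, which keeps the prediction targets stable while the x-encoder is trained by gradient descent.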
Abstract
Joint-Embedding Predictive Architectures (JEPA) have recently gained popularity as a promising approach to self-supervised learning. Vision transformers trained with JEPA produce embeddings of images and videos that have been shown to be highly suitable for downstream tasks such as classification and segmentation. In this paper, we show how to adapt the JEPA architecture to reinforcement learning from images. We discuss model collapse, show how to prevent it, and provide example results on the classic Cart Pole task.
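One common way to detect and counteract collapse is a variance regularizer in the style of the variance term of VICReg, sketched below as an assumption: the source does not specify the regularizer's exact form, and the function name `variance_regularizer`, the target standard deviation `gamma = 1.0`, and `eps = 1e-4` are illustrative choices. The penalty is zero when every embedding dimension varies enough across a batch and grows as the encoder collapses toward a constant output.

```python
import math
from typing import List


def variance_regularizer(batch: List[List[float]], gamma: float = 1.0, eps: float = 1e-4) -> float:
    """Hinge penalty on the per-dimension standard deviation of a batch of embeddings.

    Hypothetical VICReg-style form: for each dimension j, penalize
    max(0, gamma - sqrt(Var(z_j) + eps)), then average over dimensions.
    """
    if len(batch) < 2:
        raise ValueError("need at least two embeddings to estimate variance")
    n = len(batch)
    dim = len(batch[0])
    if any(len(z) != dim for z in batch):
        raise ValueError("all embeddings must have the same dimension")
    penalty = 0.0
    for j in range(dim):
        column = [z[j] for z in batch]
        mean = sum(column) / n
        var = sum((x - mean) ** 2 for x in column) / (n - 1)  # unbiased variance estimate
        std = math.sqrt(var + eps)  # eps keeps the square root well-behaved at zero variance
        penalty += max(0.0, gamma - std)
    return penalty / dim
```

For instance, a fully collapsed batch such as `[[0.5, 0.5]] * 4` yields a penalty of about 0.99, while `[[-2.0, 2.0], [2.0, -2.0]]` yields 0.0. In training, such a term would typically be added to the prediction loss with a weighting coefficient, which is a further hypothetical hyperparameter.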
