JEPA for RL: Investigating Joint-Embedding Predictive Architectures for Reinforcement Learning

Tristan Kenneweg, Philip Kenneweg, Barbara Hammer

TL;DR

Problem: RL from image observations is slow because of the high-dimensional inputs.
Approach: adapt the Joint-Embedding Predictive Architecture (JEPA) to RL, using a vision transformer to learn compact, predictive embeddings.
Contributions: (1) a concrete JEPA-RL pipeline with an x-encoder that processes recent frames and a momentum-updated y-encoder, (2) collapse-mitigation strategies, including propagating the actor/critic loss gradients into the encoder and a variance regularizer, and (3) empirical validation on Cart Pole, showing strong gains when JEPA is combined with RL updates and stable training with regularization.
Significance: indicates that JEPA is a promising direction for self-supervised representation learning in RL from pixel inputs.
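The momentum-updated y-encoder is commonly realized as an exponential moving average (EMA) of the x-encoder's weights, as in I-JEPA and BYOL. The sketch below assumes that form for this pipeline. Flat parameter lists stand in for real network weights, and the names `ema_update`, `prediction_loss` and `momentum` are hypothetical. The prediction loss is shown as mean squared error in embedding space, which is also an assumption and not the paper's stated objective.

```python
from typing import List

Vector = List[float]


def ema_update(target: Vector, online: Vector, momentum: float = 0.99) -> Vector:
    """Momentum (EMA) update of y-encoder weights: target <- m * target + (1 - m) * online.

    Assumption: the y-encoder tracks the x-encoder via EMA, as in I-JEPA/BYOL.
    """
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("momentum must lie in [0, 1]")
    if len(target) != len(online):
        raise ValueError("parameter vectors must have the same length")
    return [momentum * t + (1.0 - momentum) * o for t, o in zip(target, online)]


def prediction_loss(predicted: Vector, target: Vector) -> float:
    """Mean squared error between predicted and target embeddings.

    The target comes from the y-encoder and is treated as a constant,
    so no gradient would flow back into the y-encoder.
    """
    if len(predicted) != len(target):
        raise ValueError("embeddings must have the same length")
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


# Toy usage: the target weights move 10% of the way toward the online weights.
online_params = [1.0, 2.0]
target_params = [0.0, 0.0]
target_params = ema_update(target_params, online_params, momentum=0.9)
```

A momentum close to 1 makes the y-encoder change slowly. This gives the predictor a stable target, which is the usual motivation for the design.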

Abstract

Joint-Embedding Predictive Architectures (JEPA) have recently gained popularity as architectures for self-supervised learning. Vision transformers trained with JEPA produce embeddings from images and videos, and these embeddings have proven well suited to downstream tasks such as classification and segmentation. In this paper, we show how to adapt the JEPA architecture to reinforcement learning from images. We discuss model collapse, show how to prevent it, and present example results on the classic Cart Pole task.
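A VICReg-style variance term illustrates the collapse prevention mentioned in the abstract and the variance regularizer named in the TL;DR. This term penalizes every embedding dimension whose standard deviation across a batch drops below a threshold. The exact form of the paper's regularizer is not given here, so this form, the threshold `gamma`, the constant `eps` and the function name are all assumptions.

```python
import math
from typing import List


def variance_regularizer(
    embeddings: List[List[float]], gamma: float = 1.0, eps: float = 1e-4
) -> float:
    """Hinge penalty on the per-dimension standard deviation across a batch.

    Assumption: VICReg-style variance term, mean_d max(0, gamma - sqrt(Var_d + eps)).
    Dimensions whose std falls below gamma are penalized. This discourages the
    encoder from mapping all observations to the same embedding (collapse).
    """
    n = len(embeddings)
    if n < 2:
        raise ValueError("need at least two embeddings to estimate variance")
    dim = len(embeddings[0])
    penalty = 0.0
    for d in range(dim):
        column = [e[d] for e in embeddings]
        mean = sum(column) / n
        var = sum((x - mean) ** 2 for x in column) / (n - 1)  # unbiased estimate
        std = math.sqrt(var + eps)
        penalty += max(0.0, gamma - std)
    return penalty / dim
```

A fully collapsed batch, where all embeddings are identical, yields a penalty close to `gamma`. A well-spread batch yields zero, so the term only acts once the encoder starts to collapse.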

Paper Structure

This paper contains 4 sections, 3 equations, 3 figures.

Figures (3)

  • Figure 1: Overview of JEPA as proposed by Yann LeCun.
  • Figure 2: Overview of our JEPA pipeline as adapted to reinforcement learning. We feed the patch embeddings of all frames from $f_{t-2}$ to $f_{t}$ into the x-encoder $V(\theta,x)$.
  • Figure 3: Average episodic return over the first 100k environment steps for all four configurations. Each graph shows results aggregated over 5 runs.
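Figure 2 describes feeding the patch embeddings of frames $f_{t-2}$ to $f_{t}$ into the x-encoder. The sketch below shows one way to assemble that input sequence, under several assumptions. Frames are plain nested lists of grayscale pixels, and the helper names are hypothetical. Each flattened patch also gets its relative frame index appended, as a stand-in for the learned linear projection and positional embeddings that a vision transformer would normally apply.

```python
from typing import List

Frame = List[List[float]]  # grayscale image as rows of pixel values


def patchify(frame: Frame, patch: int) -> List[List[float]]:
    """Split a frame into non-overlapping patch x patch tiles, each flattened row-major."""
    h, w = len(frame), len(frame[0])
    if h % patch or w % patch:
        raise ValueError("frame size must be divisible by the patch size")
    tiles = []
    for r in range(0, h, patch):
        for c in range(0, w, patch):
            tiles.append([frame[r + i][c + j] for i in range(patch) for j in range(patch)])
    return tiles


def x_encoder_tokens(frames: List[Frame], patch: int, window: int = 3) -> List[List[float]]:
    """Concatenate patch tokens from the last `window` frames (f_{t-2}..f_t when window=3).

    Hypothetical: the relative frame index k (0 = oldest) is appended to each token
    so that temporal order survives the concatenation.
    """
    if len(frames) < window:
        raise ValueError("not enough frames in the history")
    tokens = []
    for k, frame in enumerate(frames[-window:]):
        for tile in patchify(frame, patch):
            tokens.append(tile + [float(k)])
    return tokens
```

For example, three 84x84 frames with 12x12 patches yield 3 x 49 = 147 tokens for the x-encoder.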