Learning and Leveraging World Models in Visual Representation Learning
Quentin Garrido, Mahmoud Assran, Nicolas Ballas, Adrien Bardes, Laurent Najman, Yann LeCun
TL;DR
The paper introduces Image World Models (IWM) within the Joint-Embedding Predictive Architecture (JEPA) to learn latent-space transformations that go beyond masked image modeling (MIM). It identifies three factors for a successful IWM: how the predictor is conditioned on the transformation, how difficult the transformations are, and the capacity of the predictor. It also shows that finetuning the predictor on top of a frozen encoder can match or surpass encoder finetuning with efficiency gains, while also enabling multi-task learning. IWMs yield a controllable spectrum of representations, from invariant to equivariant, so the representation can be tailored to downstream tasks such as classification and segmentation. The work offers practical guidelines for building reusable world models in visual representation learning and highlights their potential to bridge the contrastive and MIM paradigms while adapting efficiently across tasks.
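The predictor-finetuning protocol described above can be sketched as a toy. This is a hypothetical illustration, not the paper's code. The frozen encoder is a fixed 2×2 linear map (`ENCODER_W`). The pretrained predictor is stood in for by a linear map with illustrative starting weights. Plain gradient descent on a squared error replaces the real task loss. The names `encode`, `predict`, `mse`, and `finetune_predictor` are invented for this sketch. Only the predictor's parameters are updated; the encoder weights are read but never written.

```python
"""Hypothetical toy: finetune a predictor on top of a frozen encoder.

Not the authors' implementation. The encoder, predictor, data, and loss are
illustrative stand-ins chosen so the example runs with the standard library.
"""
from typing import List, Sequence, Tuple

Dataset = List[Tuple[List[float], float]]

# Hypothetical "pretrained" encoder weights; frozen, so never updated below.
ENCODER_W: List[List[float]] = [[1.0, 1.0], [1.0, -1.0]]


def encode(x: Sequence[float]) -> List[float]:
    # Frozen encoder: a fixed linear projection of the input.
    return [sum(w * xi for w, xi in zip(row, x)) for row in ENCODER_W]


def predict(w: Sequence[float], b: float, z: Sequence[float]) -> float:
    # Trainable predictor, reduced here to a linear map plus bias.
    return sum(wi * zi for wi, zi in zip(w, z)) + b


def mse(w: Sequence[float], b: float, data: Dataset) -> float:
    # Mean squared error of the predictor on frozen-encoder features.
    return sum((predict(w, b, encode(x)) - y) ** 2 for x, y in data) / len(data)


def finetune_predictor(data: Dataset, w_init: Sequence[float], b_init: float,
                       steps: int = 500, lr: float = 0.2) -> Tuple[List[float], float]:
    # Start from the (hypothetical) pretrained predictor weights.
    w, b = list(w_init), b_init
    n = len(data)
    for _ in range(steps):
        grad_w, grad_b = [0.0] * len(w), 0.0
        for x, y in data:
            z = encode(x)  # features from the frozen encoder
            err = predict(w, b, z) - y
            for i, zi in enumerate(z):
                grad_w[i] += 2.0 * err * zi / n
            grad_b += 2.0 * err / n
        # Gradient step on the predictor only; ENCODER_W is left untouched.
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


if __name__ == "__main__":
    data = [([1.0, 0.0], 2.0), ([0.0, 1.0], 0.0), ([1.0, 1.0], 2.0), ([2.0, 0.0], 4.0)]
    w, b = finetune_predictor(data, w_init=[0.5, 0.5], b_init=0.0)
    print(f"loss before: {mse([0.5, 0.5], 0.0, data):.3f}, after: {mse(w, b, data):.2e}")
```

No gradient flows into the encoder, so only the predictor's parameters need updates. This is one plausible source of the efficiency gains over encoder finetuning. For clarity, the toy simply recomputes the frozen features at every step.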
Abstract
Joint-Embedding Predictive Architecture (JEPA) has emerged as a promising self-supervised approach that learns by leveraging a world model. While JEPA has so far been limited to predicting missing parts of an input, we explore how to generalize its prediction task to a broader set of corruptions. We introduce Image World Models (IWM), an approach that goes beyond masked image modeling and learns to predict the effect of global photometric transformations in latent space. We study the recipe for learning performant IWMs and show that it relies on three key aspects: conditioning, prediction difficulty, and capacity. Additionally, we show that the predictive world model learned by IWM can be adapted through finetuning to solve diverse tasks; a finetuned IWM world model matches or surpasses the performance of previous self-supervised methods. Finally, we show that learning with an IWM allows one to control the abstraction level of the learned representations, learning either invariant representations, like contrastive methods, or equivariant representations, like masked image modeling.
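A toy sketch can make the latent prediction task concrete. It contrasts a conditioned predictor with an unconditioned one, under loudly simplifying assumptions. This is not the authors' implementation. An "image" is a list of floats. The only photometric transformation is a global brightness/contrast change. The encoder and predictor are linear maps. Masking is omitted. The EMA target encoder is an assumption borrowed from I-JEPA-style training. The function names (`photometric`, `encode`, `predict`, `latent_loss`, `ema_update`) are hypothetical.

```python
"""Hypothetical toy of an IWM-style latent prediction objective.

Not the authors' implementation: images are lists of floats, the only
photometric transformation is a global brightness/contrast change, the
encoder and predictor are linear maps, and masking is omitted.
"""
from typing import List, Sequence

Matrix = List[List[float]]


def photometric(img: Sequence[float], brightness: float, contrast: float) -> List[float]:
    # Global transformation: every pixel is rescaled and shifted the same way.
    return [contrast * p + brightness for p in img]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def encode(w: Matrix, img: Sequence[float]) -> List[float]:
    # Stand-in encoder: a linear projection into latent space.
    return matvec(w, img)


def predict(w_pred: Matrix, z_src: Sequence[float], cond: Sequence[float]) -> List[float]:
    # Conditioning: the predictor sees the source latent together with the
    # parameters of the transformation that relates the two views.
    return matvec(w_pred, list(z_src) + list(cond))


def latent_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    # Mean squared error between latents; no pixels are reconstructed.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def ema_update(target_w: Matrix, online_w: Matrix, momentum: float = 0.99) -> Matrix:
    # Assumed I-JEPA-style target encoder: an exponential moving average of
    # the online encoder, which supplies targets without receiving gradients.
    return [[momentum * t + (1.0 - momentum) * o for t, o in zip(t_row, o_row)]
            for t_row, o_row in zip(target_w, online_w)]


if __name__ == "__main__":
    img = [0.2, 0.8]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    # Predictor that subtracts the conditioned brightness (hand-set, not learned).
    w_pred = [[1.0, 0.0, -1.0, 0.0], [0.0, 1.0, -1.0, 0.0]]
    z_tgt = encode(identity, img)  # target view: untransformed
    z_src = encode(identity, photometric(img, brightness=0.3, contrast=1.0))
    print(latent_loss(predict(w_pred, z_src, [0.3, 1.0]), z_tgt))  # conditioned
    print(latent_loss(predict(w_pred, z_src, [0.0, 0.0]), z_tgt))  # conditioning withheld
```

When the brightness is passed in, the hand-set predictor recovers the target latent up to floating-point error. When the conditioning is zeroed out, the same predictor incurs a loss of about 0.09. In miniature, this shows why the predictor benefits from knowing which transformation it is asked to model.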
