CrossJEPA: Cross-Modal Joint-Embedding Predictive Architecture for Efficient 3D Representation Learning from 2D Images

Avishka Perera, Kumal Hewagamage, Saeedha Nazar, Kavishka Abeywardana, Hasitha Gallella, Ranga Rodrigo, Mohamed Afham

TL;DR

CrossJEPA introduces a lightweight, cross-modal Joint-Embedding Predictive Architecture that transfers knowledge from a frozen 2D image foundation model to 3D point clouds without relying on masking. By predicting rendered-view image embeddings from 3D points with pose-based latent conditioning and caching embeddings, it delivers strong 3D representations with far fewer parameters and training hours than prior cross-modal SSL methods. The approach achieves state-of-the-art linear probing on ModelNet40 and competitive results on real-world ScanObjectNN, while offering substantial data efficiency and practical training speedups. The work also provides information-theoretic and predictive-coding justifications for its design and establishes a foundation for efficient, resource-conscious 3D representation learning via cross-modal knowledge distillation.

Abstract

Image-to-point cross-modal learning has emerged to address the scarcity of large-scale 3D datasets in 3D representation learning. However, current methods that leverage 2D data often result in large, slow-to-train models, making them computationally expensive and difficult to deploy in resource-constrained environments. The architecture design of such models is therefore critical, determining their performance, memory footprint, and compute efficiency. The Joint-embedding Predictive Architecture (JEPA) has gained wide popularity in self-supervised learning for its simplicity and efficiency, but has been under-explored in cross-modal settings, partly due to the misconception that masking is intrinsic to JEPA. In this light, we propose CrossJEPA, a simple Cross-modal Joint Embedding Predictive Architecture that harnesses the knowledge of an image foundation model and trains a predictor to infer embeddings of specific rendered 2D views from corresponding 3D point clouds, thereby introducing a JEPA-style pretraining strategy beyond masking. By conditioning the predictor on cross-domain projection information, CrossJEPA purifies the supervision signal from semantics exclusive to the target domain. We further exploit the frozen teacher design with a one-time target embedding caching mechanism, yielding amortized efficiency. CrossJEPA achieves a new state-of-the-art in linear probing on the synthetic ModelNet40 (94.2%) and the real-world ScanObjectNN (88.3%) benchmarks, using only 14.1M pretraining parameters (8.5M in the point encoder), and about 6 pretraining hours on a standard single GPU. These results position CrossJEPA as a performant, memory-efficient, and fast-to-train framework for 3D representation learning via knowledge distillation. We analyze CrossJEPA intuitively, theoretically, and empirically, and extensively ablate our design choices. Code will be made available.
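The one-time target caching described in the abstract can be illustrated with a minimal sketch. Because the teacher is frozen, its embedding for a given rendered view never changes, so it can be computed once and reused in every epoch. Everything named here (`TargetCache`, `toy_frozen_teacher`, `run_epochs`, and the `(object_id, view)` cache key) is hypothetical and chosen for illustration; it is not the authors' implementation, and the hash-based "teacher" merely stands in for a real image foundation model.

```python
import hashlib
from typing import Callable, Dict, List, Tuple

Embedding = List[float]
Key = Tuple[str, int]  # (object id, view index): a hypothetical key layout


def toy_frozen_teacher(render: bytes, dim: int = 4) -> Embedding:
    """Stand-in for a frozen image encoder: a deterministic hash-based vector."""
    digest = hashlib.sha256(render).digest()
    return [b / 255.0 for b in digest[:dim]]


class TargetCache:
    """Computes each teacher embedding once and reuses it afterwards."""

    def __init__(self, teacher: Callable[[bytes], Embedding]) -> None:
        self.teacher = teacher
        self.store: Dict[Key, Embedding] = {}
        self.teacher_calls = 0  # counts real teacher forward passes

    def get(self, object_id: str, view: int, render: bytes) -> Embedding:
        key = (object_id, view)
        if key not in self.store:
            # First request for this view: run the teacher and cache the result.
            self.teacher_calls += 1
            self.store[key] = self.teacher(render)
        return self.store[key]


def run_epochs(cache: TargetCache, renders: Dict[Key, bytes], epochs: int) -> int:
    """Requests every target once per epoch; returns the teacher call count."""
    for _ in range(epochs):
        for (object_id, view), render in renders.items():
            cache.get(object_id, view, render)
    return cache.teacher_calls
```

With three rendered views and five epochs, the teacher in this sketch runs three times rather than fifteen, which is the sense in which the cost is amortized over training.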

Paper Structure

This paper contains 43 sections, 54 equations, 10 figures, 17 tables.

Figures (10)

  • Figure 1: ModelNet40 Linear Evaluation. Pretraining time on a single NVIDIA RTX 4090 versus best accuracy with an SVM linear classifier on ModelNet40. We compare CrossJEPA with recent cross-modal and unimodal point baselines from 2022 onward, covering different SSL architectures: generative (I2P-MAE, ACT, Point-M2AE), joint-embedding architecture (JEA; CrossPoint), hybrid of generative and JEA (ReCon), and JEPA (Point-JEPA). Red markers denote cross-modal methods, while blue markers denote unimodal point methods. CrossJEPA achieves the best trade-off between accuracy, compute, and memory, producing rich point cloud representations with significantly less pretraining compute and fewer parameters. All methods except CrossPoint use a standard transformer architecture. The time required to reach the best accuracy, along with the total number of learnable parameters, is reported (see Table \ref{table:linear_probing}).
  • Figure 2: CrossJEPA: CrossJEPA’s training objective is to learn generic features by predicting an image representation $s_i$ from a point cloud representation $s_p$. This is achieved using a distillation architecture that leverages the physical relationship between a point cloud $p$ and its 2D render $i$. The student's learnable point encoder $P$ generates $s_p$, while a frozen, pretrained image encoder $I$ serves as the teacher, providing the target $s_i$. A predictor network is trained to map $s_p$ to a specific $s_i^{[z]}$. The image semantics that are unavailable in the point representation are explicitly provided to the predictor as latent information $z$ to make the predictions more definitive. Note that the image encoder is frozen; hence, we employ precomputed image embeddings and color histograms, as depicted in Fig. \ref{fig:xjepa-caching} in the Supplementary.
  • Figure 3: Decay of the training loss under different latent information configurations. The plot shows how the discrepancy between the image representation (target) and the model prediction decreases over epochs. Increasing the latent information reduces the training loss drastically.
  • Figure 4: Common SSL Architectures. (a) Generative architectures directly predict $\hat{y}$ from the representation $s_x$ of input $x$ (where $x$ is often a corrupted version of $y$). (b) Joint-Embedding Architectures (JEA) minimize the embedding distance between $s_x$ and $s_y$ for signals $x$ and $y$, typically through invariance-based learning. (c) Joint-Embedding Predictive Architectures (JEPA) use a predictor conditioned on minimal latent information $z$ to predict the abstract representation $s_y$ instead of reconstructing $y$, thereby focusing on semantic structure rather than modality-specific details. Conventional JEPA pretraining relies on masking a single modality $A$ by selecting random patches $[i,j]$ and predicting the embeddings of the masked regions, which limits JEPA to unimodal setups since mask correspondence across modalities introduces inconsistencies. (d) CrossJEPA moves beyond masking through a latent view prediction objective. The input modality $A$ (e.g., 3D point cloud) is projected into modality $B$ (e.g., 2D images) using cross-domain projection parameters, and the predictor learns to infer the representation of $y$ in modality $B$ conditioned on latent information $\gamma$ that is exclusive to modality $B$. This yields a cleaner learning signal, removes modality-$B$-specific information originating in the frozen $y$-encoder (modality $B$ teacher) from the gradient path, and enables the point student to focus on mutual cross-modal information to learn richer 3D representations.
  • Figure 5: Cross-modal Architecture Selection. We explore three different cross-modal architectures: (a) $P_2I$ and $P_2P$ learning objectives. (b) $P_2I$ and $I_2I$ learning objectives. (c) $P_2I$ learning objective alone. The $P_2I$ learning objective attempts to predict the representations of views of an object, specified by a view index $\gamma$, from a point cloud representation. The $P_2P$ learning objective is the same as in Point-JEPA: the representations of masked point groups are predicted from a context point group. The $I_2I$ learning objective attempts to predict the representation of a target view ($y$) from a context view ($x$).
  • ...and 5 more figures
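
The training objective described in the Figure 2 caption can be written compactly as follows. This is a notational sketch only: the predictor symbol $g_\phi$, the stop-gradient operator $\mathrm{sg}$, and the generic distance $d$ are assumptions introduced here for illustration, and this excerpt does not state which distance the authors use. Here $i$ denotes the 2D render associated with the latent information $z$.

```latex
s_p = P(p), \qquad
s_i^{[z]} = \mathrm{sg}\big(I(i)\big), \qquad
\mathcal{L}(P, g_\phi) = d\big(g_\phi(s_p, z),\; s_i^{[z]}\big)
```

Only the point encoder $P$ and the predictor $g_\phi$ receive gradients in this formulation; since the image encoder $I$ is frozen, the targets $s_i^{[z]}$ are constants with respect to training, which is what makes precomputing them possible.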