LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures
Hai Huang, Yann LeCun, Randall Balestriero
TL;DR
This work adapts the vision-inspired Joint Embedding Predictive Architecture (JEPA) to large language models by introducing LLM-JEPA, an embedding-space objective that complements the standard next-token loss. The method combines autoregressive text reconstruction with a JEPA term that uses a predictor to align the embeddings of two views of the same content (Text and Code). This is implemented efficiently through a custom two-block attention mask. Empirical results across multiple model families and datasets show consistent gains in both finetuning and pretraining, with evidence of structured latent spaces, improved reasoning and generation, and resilience to overfitting. To reduce the extra compute cost, the authors propose Loss Dropout and show that it yields faster convergence while maintaining performance, which suggests that LLM-JEPA is practical for scalable, general-purpose LLM training and finetuning. Formally, the objective $\mathcal{L}_{\rm LLM-JEPA}$ adds a cosine-distance-based embedding regularization $d({\rm Pred}(\mathrm{Enc}({\rm Text})), \mathrm{Enc}({\rm Code}))$, scaled by $\lambda$, to the standard $\mathcal{L}_{\rm LLM}$ loss. Overall, LLM-JEPA offers a principled path to improving LLM representations and generative capabilities across diverse tasks, with practical speedups via Loss Dropout and applicability beyond code-like data.
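The objective summarized above can be sketched in a few lines of plain Python. Written out, it reads $\mathcal{L}_{\rm LLM-JEPA} = \mathcal{L}_{\rm LLM} + \lambda \cdot d({\rm Pred}(\mathrm{Enc}({\rm Text})), \mathrm{Enc}({\rm Code}))$, where we assume $d(a, b) = 1 - \frac{a^\top b}{\|a\|\,\|b\|}$. This is a minimal illustration under stated assumptions, not the authors' implementation: all function names are hypothetical, the `drop_prob` parameter is an assumed form of Loss Dropout (skip the JEPA term for a step with that probability), and the mask layout is an assumed form of the two-block attention mask (both views packed into one sequence, each attending causally only to its own tokens). Logits and embeddings are plain lists rather than tensors.

```python
import math
import random
from typing import List, Sequence

# Hypothetical sketch of the LLM-JEPA loss terms; not the authors' code.


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """d(a, b) = 1 - cos(a, b); assumed form of the JEPA distance."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def next_token_loss(logits: List[List[float]], targets: List[int]) -> float:
    """Mean cross-entropy of each position's logits against its next-token id."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max logit for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
    return total / len(targets)


def llm_jepa_loss(logits, targets, pred_text_emb, code_emb,
                  lam=1.0, drop_prob=0.0, rng=random):
    """L_LLM + lam * d(Pred(Enc(Text)), Enc(Code)).

    drop_prob is a hypothetical Loss Dropout rate: with that probability
    the JEPA term is skipped for this step (drop_prob=0 always keeps it).
    """
    loss = next_token_loss(logits, targets)
    if rng.random() >= drop_prob:
        loss += lam * cosine_distance(pred_text_emb, code_emb)
    return loss


def two_block_causal_mask(len_a: int, len_b: int) -> List[List[bool]]:
    """mask[i][j] is True if position i may attend to position j.

    Assumption: view A occupies the first len_a positions and view B the
    next len_b; each position attends causally within its own block only.
    """
    n = len_a + len_b
    block = [0] * len_a + [1] * len_b
    return [[j <= i and block[i] == block[j] for j in range(n)]
            for i in range(n)]
```

For example, `two_block_causal_mask(2, 2)` lets the two Code positions attend to each other but never to the Text positions, so one forward pass over the packed sequence could presumably produce separate embeddings for both views instead of requiring two passes.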
Abstract
Large Language Model (LLM) pretraining, finetuning, and evaluation rely on input-space reconstruction and generative capabilities. Yet it has been observed in vision that embedding-space training objectives, e.g., with Joint Embedding Predictive Architectures (JEPAs), are far superior to their input-space counterparts. This mismatch between how language and vision models are trained raises a natural question: can language training methods learn a few tricks from vision ones? The absence of JEPA-style LLMs testifies to the challenge of designing such objectives for language. In this work, we take a first step in that direction by developing LLM-JEPA, a JEPA-based solution for LLMs applicable to both finetuning and pretraining. Thus far, LLM-JEPA outperforms standard LLM training objectives by a significant margin across models while remaining robust to overfitting. These findings hold across numerous datasets (NL-RX, GSM8K, Spider, RottenTomatoes) and various models from the Llama3, OpenELM, Gemma2 and Olmo families. Code: https://github.com/rbalestr-lab/llm-jepa.
