LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures

Hai Huang, Yann LeCun, Randall Balestriero

TL;DR

This work adapts the Joint Embedding Predictive Architecture (JEPA), originally developed for vision, to large language models by introducing LLM-JEPA, an embedding-space objective that complements the standard next-token loss. The method combines autoregressive text reconstruction with a JEPA term that uses a predictor to align the embeddings of two views (Text and Code), implemented efficiently through a custom two-block attention mask. Empirical results across multiple model families and datasets show consistent gains in fine-tuning and pretraining, with evidence of structured latent spaces, improved reasoning and generation, and resilience to overfitting. To offset the extra compute of the JEPA term, the authors propose Loss Dropout and demonstrate faster convergence without sacrificing performance, suggesting LLM-JEPA's practical potential for scalable, general-purpose LLM training and fine-tuning. The approach is formalized by the objective $\mathcal{L}_{\rm LLM-JEPA} = \mathcal{L}_{\rm LLM} + \lambda \, d({\rm Pred}({\rm Enc}({\rm Text})), {\rm Enc}({\rm Code}))$, which adds a cosine-distance-based embedding regularization $d$, scaled by $\lambda$, to the standard $\mathcal{L}_{\rm LLM}$ loss. Overall, LLM-JEPA offers a principled path to improving LLM representations and generative capabilities across diverse tasks, with practical speedups via loss dropout and broader applicability beyond code-like data.
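To make the objective concrete, here is a minimal sketch in plain Python (standard library only). It assumes $d$ is the cosine distance $1 - \cos(a, b)$ and treats the predicted Text embedding and the Code embedding as given vectors; it does not model the LLM, the encoder, or the predictor. All function names are hypothetical, and the mask layout (Text tokens first, then Code tokens, each block causal and unable to see the other, so one forward pass can produce both embeddings) is an assumption about what "two-block attention mask" means, not the authors' implementation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine_distance(a: Vector, b: Vector) -> float:
    """Assumed form of d: 1 - cos(a, b). 0 = aligned, 1 = orthogonal, 2 = opposite."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine distance is undefined for zero vectors")
    return 1.0 - dot / (norm_a * norm_b)


def llm_jepa_loss(loss_llm: float, pred_text_emb: Vector, code_emb: Vector, lam: float) -> float:
    """Hypothetical helper: L_LLM-JEPA = L_LLM + lam * d(Pred(Enc(Text)), Enc(Code)).

    loss_llm is the usual next-token cross-entropy, computed elsewhere;
    pred_text_emb stands in for Pred(Enc(Text)) and code_emb for Enc(Code).
    """
    return loss_llm + lam * cosine_distance(pred_text_emb, code_emb)


def two_block_mask(n_text: int, n_code: int) -> List[List[bool]]:
    """Hypothetical two-block causal mask; mask[i][j] is True if query i may attend to key j.

    Assumed layout: Text tokens at positions [0, n_text), Code tokens at
    [n_text, n_text + n_code). Each block is causal and cannot see the other,
    so one pass over the packed sequence encodes both views independently.
    """
    n = n_text + n_code
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        block_start = 0 if i < n_text else n_text  # first position of i's own block
        for j in range(block_start, i + 1):  # causal: keys up to and including i
            mask[i][j] = True
    return mask
```

For example, `two_block_mask(2, 2)` yields the rows `[T, F, F, F]`, `[T, T, F, F]`, `[F, F, T, F]`, `[F, F, T, T]`: the Code tokens attend only to earlier Code tokens. Under this assumption, packing both views into one sequence replaces two separate encoder passes with one.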

Abstract

Large Language Model (LLM) pretraining, finetuning, and evaluation rely on input-space reconstruction and generative capabilities. Yet, it has been observed in vision that embedding-space training objectives, e.g., with Joint Embedding Predictive Architectures (JEPAs), are far superior to their input-space counterparts. That mismatch in how training is achieved between language and vision opens up a natural question: can language training methods learn a few tricks from the vision ones? The lack of JEPA-style LLMs is a testament to the challenge of designing such objectives for language. In this work, we take a first step in that direction by developing LLM-JEPA, a JEPA-based solution for LLMs applicable to both finetuning and pretraining. Thus far, LLM-JEPA outperforms the standard LLM training objectives by a significant margin across models, all while being robust to overfitting. These findings hold across numerous datasets (NL-RX, GSM8K, Spider, RottenTomatoes) and various models from the Llama3, OpenELM, Gemma2 and Olmo families. Code: https://github.com/rbalestr-lab/llm-jepa.

Paper Structure

This paper contains 23 sections, 3 equations, 9 figures, 15 tables.

Figures (9)

  • Figure 1: LLM-JEPA produces strong fine-tuned models across datasets and models.
  • Figure 2: Left: JEPA applied to NLP tasks that have ${\rm Text}$ and ${\rm Code}$, which are naturally two views of the same underlying content. Right (top): the NL-RX-SYNTH dataset, where each sample consists of a natural-language description of a regular expression (${\rm Text}$) and the regular expression itself (${\rm Code}$). Right (bottom): the Spider dataset, where ${\rm Text}$ is the database ID plus a description of the SQL query and ${\rm Code}$ is the SQL query itself.
  • Figure 3: Left: the top 100 singular values of $\operatorname{Enc}(\operatorname{Text}) - \operatorname{Enc}(\operatorname{Code})$. The curves for LLM-JEPA (ours) are a few orders of magnitude lower than those of the base model and regular fine-tuning, meaning that the mapping from ${\rm Text}$ to ${\rm Code}$ is confined to a narrow subspace, which fosters the structure seen in Figure 4. Right: losses during fine-tuning with the $\mathcal{L}_{\rm LLM}$ objective (baseline) and the $\mathcal{L}_{\rm LLM-JEPA}$ objective (our method). We track both the next-token-prediction cross-entropy loss ($\mathcal{L}_{\rm LLM}$ in the chart) and the JEPA prediction loss $d(\cdot, \cdot)$ (pred in the chart), although the latter does not contribute to training in the baseline case. Accuracy is $51.95\%$ with $\mathcal{L}_{\rm LLM}$ and $71.10\%$ with $\mathcal{L}_{\rm LLM-JEPA}$. Since both runs reach similar $\mathcal{L}_{\rm LLM}$ values, the next-token loss cannot explain the accuracy gap. pred stays constant under $\mathcal{L}_{\rm LLM}$ but is minimized under $\mathcal{L}_{\rm LLM-JEPA}$, so pred is likely the main reason for the gap.
  • Figure 4: $t$-SNE plot of ${\rm Text}$ and ${\rm Code}$ representations for (a) the baseline fine-tuned with the next-token prediction (NTP) loss and (b) LLM-JEPA (ours) with $k=0$. LLM-JEPA (ours) induces a clear structure in the representations, whereas fine-tuning with the NTP loss disrupts the structure present in the base model. A full version of this figure is provided in a separate section of the paper.
  • Figure 5: LLM-JEPA converges faster than regular fine-tuning for the same compute budget (in PFLOPs). Furthermore, random JEPA-loss dropout (LD) saves PFLOPs and improves accuracy at a fixed amount of compute. $LD=0$ corresponds to regular LLM-JEPA. Learning rate $lr=2\times10^{-5}$ and $k=1$; $\lambda$ varies.
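Figure 5 reports that random JEPA-loss dropout (LD) saves compute. The sketch below illustrates one plausible reading, assuming LD is the per-step probability of skipping the JEPA term (so $LD=0$ recovers regular LLM-JEPA, as the caption states) and that skipping the term also skips the extra work needed to produce its embeddings. The helper names and the per-step cost figures are hypothetical, chosen only to show the accounting; this is not the authors' implementation.

```python
import random
from typing import Callable, Tuple


def training_step_loss(
    loss_llm: float,
    jepa_term: Callable[[], float],
    lam: float,
    ld: float,
    rng: random.Random,
) -> Tuple[float, bool]:
    """Hypothetical loss-dropout step: with probability ld, drop the JEPA term.

    jepa_term is a callable so that the (assumed) extra work for
    d(Pred(Enc(Text)), Enc(Code)) is only performed when the term is kept.
    Returns the loss and whether the JEPA term was used.
    """
    if rng.random() < ld:
        return loss_llm, False  # JEPA term dropped: plain next-token loss only
    return loss_llm + lam * jepa_term(), True


def expected_step_cost(base_cost: float, jepa_extra_cost: float, ld: float) -> float:
    """Expected per-step compute if the JEPA overhead is paid with probability 1 - ld."""
    return base_cost + (1.0 - ld) * jepa_extra_cost
```

With a hypothetical base cost of 1.0 and a JEPA overhead of 2.0 per step, `expected_step_cost(1.0, 2.0, 0.5)` is 2.0 rather than 3.0 at $LD=0$. Whether accuracy holds at a given LD is an empirical question, which is what Figure 5 measures.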