Semantic Tube Prediction: Beating LLM Data Efficiency with JEPA

Hai Huang, Yann LeCun, Randall Balestriero

TL;DR

The paper introduces the Geodesic Hypothesis, which posits that token sequences trace geodesics on a smooth semantic manifold and are therefore locally linear. Building on it, the paper proposes Semantic Tube Prediction (STP), a JEPA-style regularizer that confines hidden-state trajectories to a tubular neighborhood of the geodesic.

Abstract

Large Language Models (LLMs) obey consistent scaling laws -- empirical power-law fits that predict how loss decreases with compute, data, and parameters. While predictive, these laws are descriptive rather than prescriptive: they characterize typical training, not optimal training. Surprisingly few works have successfully challenged the data-efficiency bounds implied by these laws -- which is our primary focus. To that end, we introduce the Geodesic Hypothesis, positing that token sequences trace geodesics on a smooth semantic manifold and are therefore locally linear. Building on this principle, we propose a novel Semantic Tube Prediction (STP) task, a JEPA-style regularizer that confines hidden-state trajectories to a tubular neighborhood of the geodesic. STP generalizes JEPA to language without requiring explicit multi-view augmentations. We show this constraint improves signal-to-noise ratio, and consequently preserves diversity by preventing trajectory collisions during inference. Empirically, STP allows LLMs to match baseline accuracy with 16$\times$ less training data on the NL-RX-SYNTH dataset, directly violating the data term of Chinchilla-style scaling laws and demonstrating that principled geometric priors can surpass brute-force scaling. Code is available at https://github.com/galilai-group/llm-jepa#stp.
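
To give a sense of scale for the "data term" mentioned above, the following is an illustrative calculation only. It assumes the parametric form commonly used for Chinchilla-style fits, with $N$ parameters, $D$ training tokens, and fitted constants $E, A, B, \alpha, \beta$; the exponent value $\beta \approx 0.28$ is the one reported for the original Chinchilla fit and is not taken from this paper. That fit describes pretraining loss, whereas the result above concerns fine-tuning accuracy, so the comparison is qualitative.

$$
L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}},
\qquad
\frac{B}{(D/16)^{\beta}} = 16^{\beta}\,\frac{B}{D^{\beta}} \approx 2.17\,\frac{B}{D^{\beta}}
\quad \text{for } \beta \approx 0.28 .
$$

Under these assumptions, cutting the data by a factor of 16 at fixed $N$ would roughly double the data term of the fitted loss, which is the kind of degradation the abstract argues STP avoids.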

Paper Structure (28 sections, 9 theorems, 44 equations, 13 figures, 5 tables)

Key Result

Proposition 2.1

The LLM training process can be modeled as a solution in the token sequence space $\mathbb{R}^{T \times d_{\rm model}}$ to the ODE:

Figures (13)

  • Figure 1: Semantic Tube improves data efficiency. (a) We hypothesize that error-free hidden state trajectories are geodesics, which are locally linear and approximated by the Semantic Tube. The dotted line depicts a trajectory distorted by training loss. Deviations perpendicular to the tube constitute noise, while the component along the geodesic represents the signal. (b) With our approach ($\mathcal{L}_{\rm NTP} + \mathcal{L}_{\rm STP}$), accuracy shows a negligible drop when the training dataset is halved, and it matches full-dataset standard fine-tuning ($\mathcal{L}_{\rm NTP}$) accuracy using only $\frac{1}{16}$ of the training data. In contrast, $\mathcal{L}_{\rm NTP}$ degrades significantly when the dataset is halved.
  • Figure 2: Two hidden state trajectories with similar prefixes pass through the Voronoi cell of the "researcher" token at different locations, leading to different next hidden states and hence different next tokens. Since $\mathcal{L}_{\rm NTP}$ cannot guarantee that $h_t$ converges to the optimal hidden state $h^\ast_t$, $h_t$ can be misplaced on another geodesic. This leads to mode collapse: the red dotted line mistakenly continues the generation and misattributes Hinton's Nobel Prize to an arbitrary person, or, if the error deviates in the opposite direction, precludes a winner.
  • Figure 3: When the sentence aligns on a geodesic, the concept direction naturally aligns.
  • Figure 4: Loss landscape. (a) When $\mathcal{L}_{\rm NTP}$ plateaus, $\mathcal{L}_{\rm STP}$ continues to decrease. Furthermore, minimizing $\mathcal{L}_{\rm NTP}$ does not automatically minimize $\mathcal{L}_{\rm STP}$. (b) Across a wide range of $\lambda$, increasing $\lambda$ on a logarithmic scale reduces $\mathcal{L}_{\rm STP}$ linearly, while $\mathcal{L}_{\rm NTP}$ remains unchanged.
  • Figure 5: Semantic Tube ($\mathcal{L}_{\rm NTP} + \mathcal{L}_{\rm STP}$, our approach) demonstrates superior performance across (a) datasets, (b) model families, and (c) model sizes compared to regular fine-tuning ($\mathcal{L}_{\rm NTP}$) and LLM-JEPA ($\mathcal{L}_{\rm NTP} + \mathcal{L}_{\rm JEPA}$).
  • ...and 8 more figures
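
The signal/noise picture in the Figure 1 caption can be illustrated with a small NumPy sketch. This is not the paper's $\mathcal{L}_{\rm STP}$, whose exact form is not given here. It assumes the simplest possible "tube": the straight chord between the first and last hidden state of a trajectory. The function names `tube_decomposition` and `tube_penalty` are hypothetical.

```python
import numpy as np


def tube_decomposition(hidden, eps=1e-12):
    """Split a hidden-state trajectory into components along and
    perpendicular to the chord from its first to its last state.

    Assumption (not from the paper): the tube axis is the straight
    line hidden[0] -> hidden[-1].

    hidden: array-like of shape (T, d)
    returns: (along, perp) with shapes (T,) and (T, d)
    """
    hidden = np.asarray(hidden, dtype=float)
    start = hidden[0]
    chord = hidden[-1] - start
    unit = chord / (np.linalg.norm(chord) + eps)  # unit vector along the axis
    offsets = hidden - start                      # (T, d) offsets from the start
    along = offsets @ unit                        # signed progress along the axis
    perp = offsets - np.outer(along, unit)        # residual off the axis
    return along, perp


def tube_penalty(hidden):
    """Mean squared perpendicular deviation from the chord (hypothetical)."""
    _, perp = tube_decomposition(hidden)
    return float(np.mean(np.sum(perp ** 2, axis=1)))


if __name__ == "__main__":
    straight = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
    bent = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]
    print("straight:", tube_penalty(straight))
    print("bent:    ", tube_penalty(bent))
```

A trajectory that lies on the chord has a penalty near zero, while the bent one is penalized for its middle state. In a training setup, a term of this kind would presumably be added to the next-token loss with a weight, which Figure 4 denotes $\lambda$.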

Theorems & Definitions (13)

  • Proposition 2.1: Training ODE
  • Definition 3.1: Local Linearity
  • Lemma 3.2: Straightening Lemma
  • Theorem 3.3: Semantic Tube
  • proof : Proof Sketch
  • Corollary 3.4: Random Tube
  • Proposition 2.1: Inference SDE
  • Proposition 6.1: Inference Cone
  • proof : Proof
  • Lemma 8.1: Data Efficiency
  • ...and 3 more