
JEPA as a Neural Tokenizer: Learning Robust Speech Representations with Density Adaptive Attention

Georgios Ioannides, Christos Constantinou, Aman Chadha, Aaron Elkins, Linsey Pang, Ravid Shwartz-Ziv, Yann LeCun

TL;DR

The paper addresses the challenge of learning robust, language-model-friendly speech representations without labeled data. It proposes a two-stage framework that decouples representation learning from waveform reconstruction. Stage 1 combines the Joint-Embedding Predictive Architecture (JEPA) with a Density Adaptive Attention Mechanism (DAAM), which yields adaptive temporal feature selection and hierarchical speech structure at a frame rate of 2.5 Hz. In Stage 2, the latent representations are discretized with Finite Scalar Quantization (FSQ) and packed with a mixed-radix scheme into 47.5 tokens per second, which a HiFi-GAN decoder converts back to waveforms. The work reports faster convergence than a JEPA baseline without DAAM and efficiency competitive with neural audio codecs. The result is a reversible, compact tokenization suitable for training language models and other sequence models on speech data.

Abstract

We introduce a two-stage self-supervised framework that combines the Joint-Embedding Predictive Architecture (JEPA) with a Density Adaptive Attention Mechanism (DAAM) for learning robust speech representations. Stage 1 uses JEPA with DAAM to learn semantic audio features via masked prediction in latent space, fully decoupled from waveform reconstruction. Stage 2 leverages these representations for efficient tokenization using Finite Scalar Quantization (FSQ) and a mixed-radix packing scheme, followed by high-fidelity waveform reconstruction with a HiFi-GAN decoder. By integrating Gaussian mixture-based density-adaptive gating into the JEPA encoder, the model performs adaptive temporal feature selection and discovers hierarchical speech structure at a low frame rate of 2.5 Hz. The resulting tokens (47.5 tokens/sec) provide a reversible, highly compressed, and language-model-friendly representation that is competitive with, and often more efficient than, existing neural audio codecs.
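
The Stage 2 tokenization step can be illustrated with a minimal sketch of FSQ followed by mixed-radix packing. The level counts, the function names (`fsq_quantize`, `pack_mixed_radix`, `unpack_mixed_radix`), and the restriction to odd level counts are hypothetical choices for illustration; this summary does not give the paper's actual configuration. Each latent dimension is bounded with tanh and rounded to one of a fixed number of integer levels. The per-dimension codes are then treated as the digits of a mixed-radix number, so one integer token identifies the whole code vector and can be unpacked exactly.

```python
import math
from typing import List, Sequence


def fsq_quantize(z: Sequence[float], levels: Sequence[int]) -> List[int]:
    """Map each latent value to an integer code in [0, levels[i] - 1].

    Hypothetical sketch: only odd level counts are supported, which keeps the
    grid symmetric around zero (FSQ needs an extra offset for even counts).
    """
    codes = []
    for value, n_levels in zip(z, levels):
        if n_levels % 2 == 0:
            raise ValueError("this sketch only supports odd level counts")
        half = n_levels // 2
        # Bound the value to (-half, half) with tanh, then snap to the nearest integer.
        q = round(half * math.tanh(value))
        codes.append(q + half)  # shift from [-half, half] to [0, n_levels - 1]
    return codes


def pack_mixed_radix(codes: Sequence[int], levels: Sequence[int]) -> int:
    """Combine per-dimension codes into one integer; dimension i has radix levels[i]."""
    token, place = 0, 1
    for code, n_levels in zip(codes, levels):
        token += code * place
        place *= n_levels
    return token


def unpack_mixed_radix(token: int, levels: Sequence[int]) -> List[int]:
    """Invert pack_mixed_radix by peeling off one digit per dimension."""
    codes = []
    for n_levels in levels:
        token, code = divmod(token, n_levels)
        codes.append(code)
    return codes


levels = [7, 5, 5, 5]  # hypothetical; vocabulary size = 7 * 5 * 5 * 5 = 875
codes = fsq_quantize([0.3, -1.2, 0.0, 2.5], levels)
token = pack_mixed_radix(codes, levels)
print(codes, token)  # [4, 0, 2, 4] 774
assert unpack_mixed_radix(token, levels) == codes
```

Packing is lossless because each digit is strictly smaller than its radix. During training, FSQ implementations typically pass gradients through the rounding with a straight-through estimator; this sketch covers only the forward mapping.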

Paper Structure

This paper contains 74 sections, 29 equations, 5 figures, and 4 tables.

Figures (5)

  • Figure 1: The input waveform is processed by three parallel pathways: (1) an online encoder (trainable, green) that processes the full audio and feeds into a predictor network (yellow) after feature-space masking with a learned mask token, (2) a target encoder (purple), updated via EMA, that also processes the full audio to generate $\mathbf{z}_{\text{target}}$, and (3) a masking strategy module (blue) that generates binary masks. The MSE loss between $\mathbf{z}_{\text{predicted}}$ and $\mathbf{z}_{\text{target}}$ (stop-gradient) is computed only on masked regions, with gradients backpropagating only through the online encoder and predictor. The target encoder provides stable representations without receiving gradients directly (Grill2020BYOL).
  • Figure 2: JEPA online encoder architecture. The input waveform passes through an initial Conv1D layer followed by 5 encoder blocks, each containing a strided Conv1D, SnakeBeta activation, residual blocks, and Gaussian Adaptive Attention gating. Features are projected through a bottleneck Conv1D layer and processed by 8 Conformer blocks (each with an FFN, multi-head attention with 16 heads, depthwise convolution, and a second FFN) to produce the final representation $\mathbf{z}$. The target encoder shares this architecture but is updated via an exponential moving average (EMA) rather than backpropagation.
  • Figure 3: JEPA predictor network architecture. The predictor takes masked context features $\mathbf{z}_{\text{masked}}$ and processes them through: (1) an expansion Conv1D layer that doubles the channel dimension, (2) two Conformer blocks separated by an intermediate Conv1D for feature refinement, and (3) a projection Conv1D that reduces back to the original dimensionality, producing predicted features $\mathbf{z}_{\text{pred}}$ at all positions including masked regions.
  • Figure 4: Stage 1 JEPA masked prediction loss (MSE) over training steps. JEPA+DAAM (blue) converges faster and to a lower final loss ($\sim 0.09$) compared to JEPA without DAAM (orange, $\sim 0.17$), demonstrating that Density Adaptive Attention enables more efficient representation learning. Both models use identical architectures except for DAAM gating.
  • Figure 5: HiFi-GAN decoder architecture (Stage 2). Quantized features $\mathbf{z}_q$ pass through a bottleneck Conv1D and are then upsampled by 5 decoder blocks. Each block contains a ConvTranspose1D upsampling layer and multi-receptive-field (MRF) residual blocks with different kernel sizes (3, 7, 11, 15, 23, 32) to capture multi-scale temporal patterns. SnakeBeta activations provide a periodic inductive bias for high-fidelity audio generation (Ziyin2020Snake).
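
Figure 1 describes two mechanics that are easy to get subtly wrong. First, the loss is averaged only over masked positions. Second, the target encoder follows the online encoder through an exponential moving average instead of gradients. The following is a minimal pure-Python sketch of both, assuming time-major lists of feature vectors. The function names and the decay value are hypothetical, not taken from the paper.

```python
from typing import List, Sequence

Frames = List[List[float]]  # one feature vector per time step


def masked_mse(pred: Frames, target: Frames, mask: Sequence[bool]) -> float:
    """MSE averaged over the features of masked time steps only.

    In an autodiff framework `target` would be detached (stop-gradient);
    plain Python floats carry no gradients, so nothing extra is needed here.
    """
    total, count = 0.0, 0
    for p_t, y_t, is_masked in zip(pred, target, mask):
        if not is_masked:
            continue  # unmasked positions do not contribute to the loss
        for p, y in zip(p_t, y_t):
            total += (p - y) ** 2
            count += 1
    return total / count if count else 0.0


def ema_update(target_params: List[float], online_params: Sequence[float], decay: float) -> None:
    """In place: target <- decay * target + (1 - decay) * online."""
    for i, online in enumerate(online_params):
        target_params[i] = decay * target_params[i] + (1.0 - decay) * online


pred = [[1.0, 2.0], [0.0, 0.0], [3.0, 3.0]]
target = [[1.0, 0.0], [5.0, 5.0], [3.0, 1.0]]
print(masked_mse(pred, target, [True, False, True]))  # 2.0 (the unmasked middle frame is ignored)

target_params = [0.0, 2.0]
ema_update(target_params, [1.0, 4.0], decay=0.5)  # hypothetical decay; real setups use values near 1
print(target_params)  # [0.5, 3.0]
```

Because the target parameters move slowly and never receive gradients, the regression target changes smoothly between steps. This is the stabilizing role Figure 1 attributes to the target encoder.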
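
Figure 2 places Gaussian Adaptive Attention gating inside each encoder block, and the abstract describes it as Gaussian mixture-based density-adaptive gating. This summary does not reproduce the exact DAAM formulation, so what follows is only a hedged sketch of the general idea for a single channel. It estimates the mean and spread of the features over time, evaluates a small mixture of Gaussians centred at offsets from that mean, and multiplies each time step by the resulting gate. The names `offsets`, `scales`, and `weights` are hypothetical stand-ins for learnable parameters.

```python
import math
from typing import List, Sequence


def gaussian_mixture_gate(
    x: Sequence[float],
    offsets: Sequence[float],
    scales: Sequence[float],
    weights: Sequence[float],
    eps: float = 1e-5,
) -> List[float]:
    """Scale each time step of one channel by a Gaussian-mixture gate in (0, 1].

    Hypothetical parameterization: component k is centred at mean(x) + offsets[k]
    with width scales[k] * std(x); the mixture weights are normalised to sum to 1.
    """
    n = len(x)
    mean = sum(x) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in x) / n + eps)
    total_w = sum(weights)
    out = []
    for v in x:
        gate = 0.0
        for delta, c, w in zip(offsets, scales, weights):
            z = (v - (mean + delta)) / (c * std)
            gate += (w / total_w) * math.exp(-0.5 * z * z)  # unnormalised Gaussian, peak 1
        out.append(v * gate)
    return out


# One component centred on the channel mean: frames far from the mean are damped.
x = [1.0, 1.0, 1.0, 5.0]
print([round(v, 3) for v in gaussian_mixture_gate(x, offsets=[0.0], scales=[1.0], weights=[1.0])])
```

Each offset shifts its component away from the mean. A trained model could therefore learn to emphasize atypical frames instead of damping them. With a component centred at the mean plus 3.0, for instance, the outlying frame in this example receives the largest gate.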
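
Figures 1 and 3 describe feature-space masking. A masking module produces a binary mask over time, and the masked positions of the online encoder's output are replaced by a learned mask token before the predictor sees them. The following sketch shows that substitution, paired with a simple contiguous-block mask. The block-masking strategy, the function names, and the fixed token value are hypothetical; in training, the token would be a learnable parameter.

```python
import random
from typing import List, Sequence


def block_mask(num_frames: int, num_blocks: int, block_len: int, rng: random.Random) -> List[bool]:
    """Hypothetical masking strategy: mark `num_blocks` contiguous spans (they may overlap)."""
    mask = [False] * num_frames
    for _ in range(num_blocks):
        start = rng.randrange(max(1, num_frames - block_len + 1))
        for t in range(start, min(num_frames, start + block_len)):
            mask[t] = True
    return mask


def apply_mask_token(
    z: Sequence[Sequence[float]], mask: Sequence[bool], mask_token: Sequence[float]
) -> List[List[float]]:
    """Build z_masked: masked frames are replaced by the (learned) mask token vector."""
    return [list(mask_token) if m else list(frame) for frame, m in zip(z, mask)]


rng = random.Random(0)
mask = block_mask(num_frames=10, num_blocks=2, block_len=3, rng=rng)
z = [[float(t), float(t)] for t in range(10)]
z_masked = apply_mask_token(z, mask, mask_token=[-1.0, -1.0])
print(mask)
print(z_masked)
```

The predictor then receives `z_masked` and, as Figure 3 notes, outputs features at every position. The same boolean mask selects which of those outputs enter the loss.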
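
Figure 5 credits SnakeBeta activations with a periodic inductive bias. The following is a minimal scalar sketch, assuming the form used in common open-source vocoder implementations: Snake with separate parameters for frequency (alpha) and magnitude (beta). The paper's exact parameterization, such as whether alpha and beta are stored in log scale, is not specified here.

```python
import math


def snake_beta(x: float, alpha: float, beta: float, eps: float = 1e-9) -> float:
    """SnakeBeta: x + (1 / beta) * sin^2(alpha * x).

    alpha sets the frequency of the periodic term and beta its magnitude; in a
    network both are learned per channel. eps guards against division by zero.
    """
    return x + (1.0 / (beta + eps)) * math.sin(alpha * x) ** 2


# The periodic term repeats every pi / alpha, on top of the identity path.
alpha, beta = 2.0, 0.5
for x in [0.0, 0.5, 0.5 + math.pi / alpha]:
    print(x, snake_beta(x, alpha, beta) - x)
```

The identity term keeps the activation monotone-friendly for gradient flow, while the added sin² term is non-negative and periodic. This is what lets the decoder model waveform periodicity more easily than with ReLU-style activations.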