T-JEPA: Augmentation-Free Self-Supervised Learning for Tabular Data

Hugo Thimonier, José Lucas De Melo Costa, Fabrice Popineau, Arpad Rimmel, Bich-Liên Doan

TL;DR

This work tackles self-supervised learning (SSL) for tabular data, where traditional augmentation-based SSL is hard to apply. It introduces T-JEPA, an augmentation-free method based on mask prediction in latent space. T-JEPA combines a context encoder, a target encoder updated as an exponential moving average (EMA) of the context encoder's weights, and a predictor, together with regularization tokens that prevent representation collapse. T-JEPA yields consistent improvements across a range of downstream models and datasets, often matching or surpassing gradient-boosted trees. The learned representations are further analyzed through embedding-space visualization and through their alignment with feature-importance rankings. The approach broadens the applicability of SSL to structured data, offering a practical, augmentation-free way to exploit unlabeled tabular data for both classification and regression tasks.
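The EMA update of the target encoder can be sketched as follows. This is a minimal NumPy illustration, not the authors' code: parameters live in a plain dict, and the helper name `ema_update` and the momentum value `0.996` are assumptions. In JEPA-style training the target encoder receives no gradient updates and changes only through this averaging step.

```python
import numpy as np


def ema_update(target_params, context_params, momentum=0.996):
    """Hypothetical helper: move each target-encoder parameter toward the
    matching context-encoder parameter (momentum value is an assumption)."""
    for name, theta_context in context_params.items():
        target_params[name] = momentum * target_params[name] + (1.0 - momentum) * theta_context
    return target_params


# Toy usage: the target starts at zeros and drifts toward the context weights.
target = {"W": np.zeros((2, 2))}
context = {"W": np.ones((2, 2))}
for _ in range(3):
    ema_update(target, context, momentum=0.9)
# after k steps every entry is approximately 1 - 0.9**k
```

A momentum close to 1 makes the target encoder move slowly, so the predictor regresses onto targets that change gradually over training.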

Abstract

Self-supervision is often used for pre-training to improve performance on a downstream task by constructing meaningful representations of samples. Self-supervised learning (SSL) generally involves generating different views of the same sample and thus requires data augmentations that are challenging to construct for tabular data. This constitutes one of the main challenges of self-supervision for structured data. In the present work, we propose a novel augmentation-free SSL method for tabular data. Our approach, T-JEPA, relies on a Joint Embedding Predictive Architecture (JEPA) and is akin to mask reconstruction in the latent space. It involves predicting the latent representation of one subset of features from the latent representation of a different subset within the same sample, thereby learning rich representations without augmentations. We use our method as a pre-training technique and train several deep classifiers on the obtained representations. Our experimental results demonstrate a substantial improvement in both classification and regression tasks, outperforming models trained directly on samples in their original data space. Moreover, T-JEPA enables some methods to consistently outperform or match the performance of traditional methods like Gradient Boosted Decision Trees. To understand why, we extensively characterize the obtained representations and show that T-JEPA effectively identifies relevant features for downstream tasks without access to the labels. Additionally, we introduce regularization tokens, a novel regularization method critical for training JEPA-based models on structured data.
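Predicting one subset of a sample's features from a different subset starts with choosing those subsets. The sketch below assumes that the context and target subsets are disjoint and that several target subsets are drawn per sample; the function `sample_masks` and its ratios are hypothetical choices for illustration, not the paper's settings.

```python
import numpy as np


def sample_masks(d, context_ratio=0.5, n_targets=2, target_ratio=0.25, rng=None):
    """Hypothetical mask sampler for one sample with d features.

    Returns the indices of the context (unmasked) features and a list of
    target index sets. Assumption: targets are drawn only from features that
    are NOT in the context, so the context and each target are disjoint
    (different target sets may still overlap each other).
    """
    assert 0.0 < context_ratio < 1.0, "some features must remain for the targets"
    rng = np.random.default_rng(rng)
    perm = rng.permutation(d)
    n_ctx = max(1, int(round(context_ratio * d)))
    context_idx = np.sort(perm[:n_ctx])
    remaining = perm[n_ctx:]
    n_tgt = min(max(1, int(round(target_ratio * d))), len(remaining))
    targets = [np.sort(rng.choice(remaining, size=n_tgt, replace=False))
               for _ in range(n_targets)]
    return context_idx, targets


ctx_idx, target_idxs = sample_masks(d=8, rng=0)
x = np.linspace(0.0, 1.0, 8)   # one toy pre-processed sample
x_context = x[ctx_idx]         # the l_m visible features fed to the context encoder
```

Because the target features are never shown to the context encoder, the only way to predict their latent representations is to exploit dependencies between features, which is what replaces augmentation-based views here.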

Paper Structure (62 sections, 22 equations, 7 figures, 18 tables)

Figures (7)

  • Figure 1: T-JEPA training pipeline. In step (a) a sample $\mathbf{x} \in \mathbb{R}^d$ is pre-processed and masked as detailed in the context-masking equation, then fed to the context encoder to obtain a representation in $\mathbb{R}^{l_\mathbf{m}\times h}$, where $l_\mathbf{m}$ is the number of unmasked features for context mask $\mathbf{m}$ and $h$ is the hidden dimension. In step (b) the unmasked representation of sample $\mathbf{x}$ is fed to the target encoder, and the features' representations are selected according to the corresponding target masks, as shown in the target-masking equation and the regularization-token pipeline figure. In step (c) the output of the context encoder is fed to the predictor to obtain a prediction for each target mask used in step (b). In step (d) we compute the $\ell_2$-distance between the target representations and their predictions.
  • Figure 2: Visualization of the representation space at various epochs of the T-JEPA pretraining on the Jannis (JA) dataset. Each plot depicts the density of transformed points in two dimensions, with darker areas indicating higher density.
  • Figure 3: Training regime of Joint-Embedding Predictive Architectures on tabular (JA) and image (ImageNet-1K) data. We display on the right a randomly selected sample's representations for each critical part of the training process. The subfigures (a) to (d) illustrate the evolving outputs of the context encoder $h_\text{context} \in \mathbb{R}^{d \times h}$. In each heat-map, rows correspond to the $d$ features, while columns represent the $h$ hidden dimensions. (a) describes the initial random initialization, (b) the collapsed equilibrium, (c) the regularization effect pushing the weights outside of the collapsed equilibrium, and (d) the convergence.
  • Figure 4: Regularization token ablation. Training loss for the JA dataset across different numbers of regularization tokens [REG].
  • Figure 5: Pairwise comparison of feature rankings using Kendall's $\tau$ correlation on the JA dataset. Rankings are derived from XGBoost feature importance, permutation importance, T-JEPA's embedding variance ($\sigma_{\text{embed}}$), and a random baseline.
  • ...and 2 more figures
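Steps (a) to (d) of the Figure 1 pipeline can be traced with a toy NumPy sketch. Every component here is an assumption made for illustration: the per-feature tokenizer, the `encoder` stand-in (a tanh layer with a mean-pooled mixing term instead of a transformer), the `predictor`, and the fixed masks. Gradients and the EMA update of the target encoder are omitted; in actual training the target branch would not be back-propagated through.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 6, 4  # number of features and hidden size (toy values)

# Hypothetical per-feature tokenizer: token_j = x_j * W_val[j] + E_feat[j]
W_val = rng.normal(size=(d, h))
E_feat = rng.normal(size=(d, h))


def tokenize(x, idx):
    """Embed the features listed in idx -> array of shape (len(idx), h)."""
    return x[idx, None] * W_val[idx] + E_feat[idx]


def encoder(tokens, A, B):
    """Toy stand-in for a transformer encoder: a per-token linear map plus a
    mean-pooled term that lets every token see the others."""
    return np.tanh(tokens @ A + tokens.mean(axis=0, keepdims=True) @ B)


def predictor(h_context, target_idx, P):
    """Predict one h-vector per target feature from the pooled context output
    and the target feature's embedding (hypothetical design)."""
    pooled = h_context.mean(axis=0, keepdims=True)  # (1, h)
    return (pooled + E_feat[target_idx]) @ P         # (len(target_idx), h)


A_ctx, B_ctx, P = (rng.normal(size=(h, h)) for _ in range(3))
A_tgt, B_tgt = A_ctx.copy(), B_ctx.copy()  # target encoder initialised as a copy

x = rng.normal(size=d)                            # one pre-processed sample
context_idx = np.array([0, 2, 3])                 # unmasked features, l_m = 3
target_masks = [np.array([1, 4]), np.array([5])]

# (a) the context encoder only sees the unmasked features -> (l_m, h)
h_context = encoder(tokenize(x, context_idx), A_ctx, B_ctx)
# (b) the target encoder sees all d features; rows are picked by each target mask
h_full = encoder(tokenize(x, np.arange(d)), A_tgt, B_tgt)
targets = [h_full[m] for m in target_masks]
# (c) one prediction per target mask
preds = [predictor(h_context, m, P) for m in target_masks]
# (d) l2 distance per target feature, averaged within each mask, then across masks
loss = float(np.mean([np.linalg.norm(p - t, axis=1).mean() for p, t in zip(preds, targets)]))
```

The shapes mirror the caption: the context output is $l_\mathbf{m}\times h$, the target encoder produces one row per feature, and the loss compares only the rows selected by the target masks.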
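The ranking comparison of Figure 5 reduces to correlating two per-feature score vectors with Kendall's $\tau$. The sketch below assumes that $\sigma_{\text{embed}}$ is the variance of each feature's embedding across samples, averaged over hidden dimensions (the paper's exact definition may differ), computes Kendall's $\tau$-a without tie handling, and uses made-up importance scores rather than real XGBoost output. In practice `scipy.stats.kendalltau`, which handles ties, would be the usual choice.

```python
import numpy as np


def kendall_tau(a, b):
    """Kendall's tau-a between two score vectors; assumes no ties."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    n = len(a)
    concordance = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            # +1 if the pair is ordered the same way in both rankings, -1 otherwise
            concordance += np.sign(a[i] - a[j]) * np.sign(b[i] - b[j])
    return concordance / (n * (n - 1) / 2)


def embedding_variance(H):
    """Assumed definition of sigma_embed. H has shape (n_samples, d, h):
    variance across samples, averaged over hidden dimensions -> shape (d,)."""
    return H.var(axis=0).mean(axis=1)


rng = np.random.default_rng(0)
scales = np.array([0.5, 1.0, 1.5, 2.0, 2.5])
H = rng.normal(size=(200, 5, 8)) * scales[None, :, None]  # toy embeddings for 5 features
sigma_embed = embedding_variance(H)
made_up_importance = np.array([0.05, 0.10, 0.20, 0.25, 0.40])  # illustrative, not real XGBoost output
tau = kendall_tau(sigma_embed, made_up_importance)
```

A $\tau$ near 1 means the label-free embedding statistic orders features the same way as the supervised importance measure, while the random baseline in Figure 5 should sit near 0.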