MET: Masked Encoding for Tabular Data

Kushal Majmundar, Sachin Goyal, Praneeth Netrapalli, Prateek Jain

TL;DR

MET introduces a reconstruction-based SSL method for tabular data that avoids handcrafted augmentations. It uses a transformer encoder-decoder with coordinate-wise embeddings, concatenates all coordinate representations for downstream finetuning, and optimizes a combined standard and adversarial reconstruction loss. Empirical results on five diverse tabular benchmarks show state-of-the-art performance on all of them, with gains of up to 9% over the previous state of the art and consistent improvements when labeled data is scarce. The work highlights the effectiveness of high-ratio masking (50-70% generally works best) and adversarial reconstruction for learning discriminative tabular representations, while also noting computational considerations and the need for a theoretical understanding of the observed benefits.
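The coordinate-wise embedding and masking described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the sizes d, k and h, the helper names make_tokens, random_mask and encode, and the per-token linear-plus-tanh "encoder" are all hypothetical stand-ins (the real encoder is a transformer and the positional encodings are learned). It shows each scalar feature becoming its own token by concatenation with a per-coordinate encoding, only visible coordinates reaching the encoder during pretraining, and the downstream representation being the concatenation of all d coordinate outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: d features, k-dim positional encoding, h-dim token outputs.
d, k, h = 8, 4, 16
pos_emb = rng.normal(size=(d, k))  # stands in for the learnable positional encodings
W_enc = rng.normal(size=(1 + k, h)) / np.sqrt(1 + k)  # stand-in encoder weights (not a transformer)


def make_tokens(x, idx):
    """One token per coordinate i in idx: the scalar x[i] concatenated with pos_emb[i]."""
    return np.concatenate([x[idx, None], pos_emb[idx]], axis=1)  # (len(idx), 1 + k)


def random_mask(num_features, mask_ratio):
    """Randomly split coordinate indices into (visible, masked) index arrays."""
    n_masked = int(round(mask_ratio * num_features))
    perm = rng.permutation(num_features)
    return np.sort(perm[n_masked:]), np.sort(perm[:n_masked])


def encode(tokens):
    """Placeholder for the transformer encoder: a shared per-token linear map plus tanh."""
    return np.tanh(tokens @ W_enc)  # (num_tokens, h)


x = rng.normal(size=d)  # one tabular row

# Pretraining: only the visible coordinates are encoded; a decoder (omitted here)
# would reconstruct the masked ones from these outputs plus mask tokens.
visible, masked = random_mask(d, mask_ratio=0.7)
z_visible = encode(make_tokens(x, visible))

# Downstream: encode every coordinate and concatenate all d outputs into one vector.
z_all = encode(make_tokens(x, np.arange(d)))
rep = z_all.reshape(-1)  # shape (d * h,)
```

With d = 8 and a masking ratio of 0.7, six coordinates are masked and two are encoded. The downstream vector rep has d * h = 128 entries, so its size grows linearly with the number of features.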

Abstract

We consider the task of self-supervised representation learning (SSL) for tabular data: tabular-SSL. Typical contrastive-learning-based SSL methods require instance-wise data augmentations, which are difficult to design for unstructured tabular data. Existing tabular-SSL methods design such augmentations in a relatively ad-hoc fashion and can fail to capture the underlying data manifold. Instead of augmentation-based approaches for tabular-SSL, we propose a new reconstruction-based method, called Masked Encoding for Tabular Data (MET), that does not require augmentations. MET is based on the popular MAE approach for vision-SSL [He et al., 2021] and uses two key ideas: (i) since each coordinate in a tabular dataset has a distinct meaning, it uses a separate representation for each coordinate, and (ii) it uses an adversarial reconstruction loss in addition to the standard one. Empirical results on five diverse tabular datasets show that MET achieves a new state of the art (SOTA) on all of these datasets, improving by up to 9% over current SOTA methods. We shed more light on how MET works via experiments on carefully designed simple datasets.
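The combination of a standard and an adversarial reconstruction loss in idea (ii) can be illustrated with a toy model. This sketch assumes, rather than reproduces from the paper, that the adversarial term perturbs the input by a delta with ||delta||_2 <= eps found by a few steps of projected gradient ascent, and that the perturbed input is also the reconstruction target. The linear map W (standing in for the transformer encoder-decoder), the fixed mask, eps, steps, lr and all helper names are hypothetical. Because the toy loss is quadratic in the input, its gradient is written out by hand; with a real network an autodiff framework would supply it.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
W = rng.normal(scale=0.3, size=(d, d))  # hypothetical linear stand-in for encoder + decoder
mask = np.array([1, 0, 1, 0, 1, 0], dtype=float)  # assumed fixed mask: 1 = visible, 0 = masked


def recon_loss(x):
    """Squared error of reconstructing the full x from its visible coordinates only."""
    residual = x - W @ (mask * x)
    return float(residual @ residual)


def recon_grad(x):
    """Gradient of recon_loss w.r.t. x: 2 M^T M x with M = I - W diag(mask)."""
    M = np.eye(d) - W * mask  # W * mask scales column j of W by mask[j], i.e. W @ diag(mask)
    return 2.0 * M.T @ (M @ x)


def adversarial_delta(x, eps=0.5, steps=10, lr=0.1):
    """Projected gradient ascent for delta, ||delta||_2 <= eps, increasing recon_loss(x + delta)."""
    delta = np.zeros_like(x)
    for _ in range(steps):
        delta = delta + lr * recon_grad(x + delta)
        norm = np.linalg.norm(delta)
        if norm > eps:
            delta = delta * (eps / norm)  # project back onto the eps-ball
    return delta


x = rng.normal(size=d)
delta = adversarial_delta(x)
standard = recon_loss(x)
adversarial = recon_loss(x + delta)
total = standard + adversarial  # combined objective the model parameters would minimize
```

Training would then minimize total with respect to the model parameters (here W), recomputing delta for each batch. Because recon_loss is a convex function of the input, no projected ascent step can decrease it, so in this toy setting the adversarial term is never smaller than the standard one.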

Paper Structure

This paper contains 22 sections, 2 equations, 4 figures, 6 tables, and 1 algorithm.

Figures (4)

  • Figure 1: MET framework for tabular-SSL. Given an input, we mask out a fraction of coordinates (features). The masked input is then concatenated with its learnable positional encodings and fed to the transformer-based encoder. The encoder outputs (learnt representations) are then passed through the decoder along with the mask tokens, and the reconstruction loss is optimized end-to-end.
  • Figure 2: (a) We analyze the representations learnt by MET on a 10-dimensional binary classification toy dataset, where each 10-dimensional point is formed by concatenating five 2D points sampled i.i.d. from the circle corresponding to its class. (b) 2D projection of the source data. (c) Mean distance between the learnt representations of the two classes as SSL training with MET proceeds. (d) 2D projection of the representations learnt by MET.
  • Figure 3: We compare the performance (downstream classification accuracy) of MET with various baselines as the fraction of labelled data used to train the downstream classifier is varied. MET consistently outperforms the baselines, even when only a small fraction of labelled data is used for downstream classifier training.
  • Figure 4: We study how downstream accuracy varies with the masking ratio in MET on four tabular classification datasets. A high masking ratio (50-70%) generally works best.
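The Figure 2 toy dataset can be generated along these lines. The caption does not give the circles' radii or centres, the angle distribution, or the sample size, so this sketch assumes two concentric circles (radius 1.0 for class 0 and 2.0 for class 1), uniformly random angles, and 200 points per class; every such value and helper name is hypothetical. mean_interclass_distance is one possible reading of the panel (c) quantity: it averages the Euclidean distance over all cross-class pairs and applies equally to raw inputs or to learnt representations.

```python
import numpy as np

# Hypothetical setup: radii, angle distribution and sample counts are assumptions,
# not values taken from the paper.


def sample_circle_points(n, radius, rng, num_points=5):
    """Return n rows of length 2 * num_points: num_points i.i.d. points on one circle, flattened."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=(n, num_points))
    xy = np.stack([radius * np.cos(theta), radius * np.sin(theta)], axis=-1)  # (n, num_points, 2)
    return xy.reshape(n, 2 * num_points)  # [x1, y1, x2, y2, ...]


def make_toy_dataset(n_per_class=200, radii=(1.0, 2.0), seed=0):
    """Class c uses the circle of radius radii[c]; returns X of shape (len(radii) * n_per_class, 10) and y."""
    rng = np.random.default_rng(seed)
    X = np.concatenate([sample_circle_points(n_per_class, r, rng) for r in radii])
    y = np.repeat(np.arange(len(radii)), n_per_class)
    return X, y


def mean_interclass_distance(Z, y):
    """Average Euclidean distance over all (class-0 row, class-1 row) pairs of Z."""
    a, b = Z[y == 0], Z[y == 1]
    diffs = a[:, None, :] - b[None, :, :]
    return float(np.linalg.norm(diffs, axis=-1).mean())


X, y = make_toy_dataset()
print(X.shape, mean_interclass_distance(X, y))
```

Under these assumptions each 2D chunk of a class-0 point is at least 1.0 away from the matching chunk of a class-1 point, so the raw-input inter-class distance is bounded below by sqrt(5); tracking the same statistic on encoder outputs during pretraining would give a curve analogous to panel (c).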