MET: Masked Encoding for Tabular Data
Kushal Majmundar, Sachin Goyal, Praneeth Netrapalli, Prateek Jain
TL;DR
MET introduces a reconstruction-based SSL method for tabular data that avoids handcrafted augmentations. It uses a transformer encoder-decoder with coordinate-wise embeddings, concatenates all coordinate representations for downstream finetuning, and optimizes a combination of a standard and an adversarial reconstruction loss. Empirical results on five diverse tabular benchmarks show state-of-the-art performance, with substantial gains over prior tabular-SSL methods and robust improvements when labeled data is scarce. The work highlights the effectiveness of high-ratio masking and adversarial reconstruction for learning discriminative tabular representations, while also noting computational considerations and the need for a theoretical understanding of the observed benefits.
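The coordinate-wise tokenization, masking, and concatenation steps described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' implementation: we assume each token is a learned per-coordinate embedding concatenated with that coordinate's scalar value, and that masking drops a random fraction of tokens before the encoder, as in MAE. The helper names, the sizes, the 0.7 mask ratio, and `encoder_stub` (a hypothetical per-token placeholder for the transformer encoder) are all illustrative.

```python
import numpy as np

def make_tokens(x, coord_emb):
    """One token per coordinate: [learned coordinate embedding, scalar value]."""
    # x: (n_features,), coord_emb: (n_features, d_emb) -> (n_features, d_emb + 1)
    return np.concatenate([coord_emb, x[:, None]], axis=1)

def random_mask(n_features, mask_ratio, rng):
    """Boolean array with True for masked coordinates (a high ratio is assumed)."""
    n_masked = int(round(mask_ratio * n_features))
    perm = rng.permutation(n_features)
    mask = np.zeros(n_features, dtype=bool)
    mask[perm[:n_masked]] = True
    return mask

def encoder_stub(tokens, W):
    """Hypothetical stand-in for the transformer encoder: per-token linear map + tanh.
    Unlike a real transformer, it does not mix information across coordinates."""
    return np.tanh(tokens @ W)

def downstream_features(x, coord_emb, W):
    # At finetuning time nothing is masked; each coordinate keeps its own
    # representation, and these are concatenated (not pooled) into one vector.
    h = encoder_stub(make_tokens(x, coord_emb), W)  # (n_features, d_model)
    return h.reshape(-1)                            # (n_features * d_model,)

rng = np.random.default_rng(0)
n_features, d_emb, d_model = 6, 3, 4
coord_emb = rng.normal(size=(n_features, d_emb))
W = rng.normal(size=(d_emb + 1, d_model))
x = rng.normal(size=n_features)

mask = random_mask(n_features, mask_ratio=0.7, rng=rng)
visible_tokens = make_tokens(x, coord_emb)[~mask]  # what the encoder sees during pretraining
features = downstream_features(x, coord_emb, W)
print(mask.sum(), visible_tokens.shape, features.shape)
```

Concatenating rather than averaging keeps each coordinate's identity in the downstream feature vector; the trade-off is that the representation grows linearly with the number of features.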
Abstract
We consider the task of self-supervised representation learning (SSL) for tabular data: tabular-SSL. Typical contrastive-learning-based SSL methods require instance-wise data augmentations, which are difficult to design for unstructured tabular data. Existing tabular-SSL methods design such augmentations in a relatively ad-hoc fashion and can fail to capture the underlying data manifold. Instead of augmentation-based approaches for tabular-SSL, we propose a new reconstruction-based method, called Masked Encoding for Tabular Data (MET), that does not require augmentations. MET is based on the popular MAE approach for vision-SSL [He et al., 2021] and uses two key ideas: (i) since each coordinate in a tabular dataset has a distinct meaning, it uses a separate representation for each coordinate, and (ii) it uses an adversarial reconstruction loss in addition to the standard one. Empirical results on five diverse tabular datasets show that MET achieves a new state of the art (SOTA) on all of these datasets and improves by up to 9% over current SOTA methods. We shed further light on how MET works via experiments on carefully designed simple datasets.
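The combination of a standard and an adversarial reconstruction loss can be illustrated with a toy model. The sketch below rests on assumptions not stated here: the adversarial term perturbs the visible input within an L2 ball of radius `eps`, the perturbation is found by a few steps of projected gradient ascent, and the model must still reconstruct the clean row. A linear map `W` stands in for the transformer encoder-decoder so that the input gradient can be written by hand. The function names and the step-size heuristic are hypothetical.

```python
import numpy as np

def recon_loss(W, x_in, x_target, visible):
    """Squared error of a toy linear 'autoencoder' that only sees visible coordinates."""
    r = W @ (x_in * visible) - x_target
    return float(r @ r)

def recon_loss_grad_input(W, x_in, x_target, visible):
    """Gradient of recon_loss with respect to x_in (zero on masked coordinates)."""
    r = W @ (x_in * visible) - x_target
    return 2.0 * visible * (W.T @ r)

def adversarial_delta(W, x, visible, eps, steps=10, step_size=None):
    """Projected gradient ascent on the reconstruction loss over an L2 ball of radius eps."""
    if step_size is None:
        step_size = 2.5 * eps / steps  # assumed heuristic, not taken from the paper
    delta = np.zeros_like(x)
    for _ in range(steps):
        g = recon_loss_grad_input(W, x + delta, x, visible)
        g_norm = np.linalg.norm(g)
        if g_norm == 0:
            break
        delta = delta + step_size * g / g_norm  # normalized ascent step
        d_norm = np.linalg.norm(delta)
        if d_norm > eps:                        # project back onto the ball
            delta = delta * (eps / d_norm)
    return delta

def met_style_loss(W, x, visible, eps):
    """Standard reconstruction loss plus the loss on the adversarially perturbed input."""
    delta = adversarial_delta(W, x, visible, eps)
    return recon_loss(W, x, x, visible) + recon_loss(W, x + delta, x, visible), delta

rng = np.random.default_rng(0)
n = 5
W = rng.normal(size=(n, n))
x = rng.normal(size=n)
visible = np.array([1.0, 0.0, 1.0, 0.0, 1.0])  # 1 = visible, 0 = masked
total, delta = met_style_loss(W, x, visible, eps=0.5)
print(total, np.linalg.norm(delta))
```

During training, the inner maximization would be recomputed for each batch and the outer objective minimized over the model parameters with `delta` held fixed; with a neural network, the hand-written gradient would be replaced by automatic differentiation.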
