TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling

Yury Gorishniy, Akim Kotelnikov, Artem Babenko

TL;DR

TabM shows that strong baselines for tabular DL can be achieved with a simple MLP backbone augmented by parameter-efficient ensembling. By training $k$ implicit submodels in parallel while sharing most weights (BatchEnsemble-style adapters), TabM delivers ensemble-like performance with far lower cost than traditional deep ensembles. Large-scale benchmarks across 46 public datasets reveal TabM as the top-performing tabular DL method, with attention- and retrieval-based approaches lagging in both reliability and efficiency, especially on domain-aware splits. The work also demonstrates that the ensemble benefit arises from collective training and weight sharing, and that submodel pruning can reduce inference cost without sacrificing performance, indicating practical paths for deployment.

Abstract

Deep learning architectures for supervised learning on tabular data range from simple multilayer perceptrons (MLP) to sophisticated Transformers and retrieval-augmented methods. This study highlights a major, yet so far overlooked opportunity for designing substantially better MLP-based tabular architectures. Namely, our new model TabM relies on efficient ensembling, where one TabM efficiently imitates an ensemble of MLPs and produces multiple predictions per object. Compared to a traditional deep ensemble, in TabM, the underlying implicit MLPs are trained simultaneously, and (by default) share most of their parameters, which results in significantly better performance and efficiency. Using TabM as a new baseline, we perform a large-scale evaluation of tabular DL architectures on public benchmarks in terms of both task performance and efficiency, which renders the landscape of tabular DL in a new light. Generally, we show that MLPs, including TabM, form a line of stronger and more practical models compared to attention- and retrieval-based architectures. In particular, we find that TabM demonstrates the best performance among tabular DL models. Then, we conduct an empirical analysis on the ensemble-like nature of TabM. We observe that the multiple predictions of TabM are weak individually, but powerful collectively. Overall, our work brings an impactful technique to tabular DL and advances the performance-efficiency trade-off with TabM -- a simple and powerful baseline for researchers and practitioners.
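The observation that the predictions are "weak individually, but powerful collectively" can be illustrated with a small numerical sketch. The setup is an assumption made for illustration only: the $k$ predictions are simulated as the target plus independent noise, and the collective prediction is taken to be their plain average. It is not the authors' experiment. For squared error, the averaged prediction is never worse than the mean of the individual errors (Jensen's inequality).

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 1000, 32  # hypothetical: 1000 objects, k = 32 predictions per object

y = rng.normal(size=n)  # synthetic regression targets
# Hypothetical submodel outputs: the target plus independent noise per submodel.
preds = y[:, None] + rng.normal(scale=1.0, size=(n, k))  # shape (n, k)

# Error of each prediction head taken on its own ("weak individually").
individual_mse = ((preds - y[:, None]) ** 2).mean(axis=0)  # shape (k,)
mean_individual_mse = individual_mse.mean()

# Error of the averaged prediction ("powerful collectively").
ensemble_mse = ((preds.mean(axis=1) - y) ** 2).mean()

print(f"mean individual MSE: {mean_individual_mse:.3f}")
print(f"averaged-prediction MSE: {ensemble_mse:.3f}")
```

With independent noise the gap is large. Real submodels that share most of their weights produce correlated errors, so the gain in practice depends on how diverse the $k$ predictions are.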

Paper Structure

This paper contains 49 sections, 11 figures, 17 tables.

Figures (11)

  • Figure 1: (Upper left) A high-level illustration of TabM. One TabM represents an ensemble of $k$ MLPs processing $k$ inputs in parallel. The remaining parts of the figure are three different parametrizations of the $k$ MLP backbones. (Upper right) $\hbox{TabM}_\text{packed}$ consists of $k$ fully independent MLPs. (Lower left) TabM is obtained by injecting three non-shared adapters $R$, $S$, $B$ in each of the $N$ linear layers of one MLP ($^*$ the initialization differs from Wen et al., 2020). (Lower right) $\hbox{TabM}_\text{mini}$ is obtained by keeping only the very first adapter $R$ of TabM and removing the remaining $3N - 1$ adapters. (Details) Input transformations such as one-hot encoding or feature embeddings (Gorishniy et al., 2022) are omitted for simplicity. Drop denotes dropout (Srivastava et al., 2014).
  • Figure 2: The performance of models described in \ref{sec:model-design} on 46 datasets from \ref{tab:datasets}, plus several baselines on the left. For a given model, one dot on a jitter plot describes the performance score on one of the 46 datasets. The box plots describe the percentiles of the jitter plots: the boxes describe the 25th, 50th, and 75th percentiles, and the whiskers describe the 10th and 90th percentiles. Outliers are clipped. The numbers at the bottom are the means and standard deviations over the jitter plots. For each model, hyperparameters are tuned. "$\text{Model}^{\times k}$" denotes an ensemble of $k$ models.
  • Figure 3: The task performance of tabular models on the 46 datasets from \ref{tab:datasets}. (Left) The means and standard deviations of the performance ranks over all datasets summarize the head-to-head comparison between the models. (Middle & Right) The relative performance w.r.t. the plain multilayer perceptron (MLP) allows reasoning about the scale and consistency of improvements over this simple baseline. One dot of a jitter plot corresponds to the performance of a model on one of the 46 datasets. The box plots visualize the 10th, 25th, 50th, 75th, and 90th percentiles of the jitter plots. Outliers are clipped. The separation into random and domain-aware dataset splits is explained in \ref{sec:model-preliminaries}. ($^*$ Evaluated under the common protocol without data augmentations.)
  • Figure 4: Training times (left) and inference throughput (right) of the models from \ref{fig:performance}. One dot represents a measurement on one dataset. $\hbox{TabM}_\text{mini}^{\dagger *}$ is the optimized $\hbox{TabM}_\text{mini}^\dagger$ (see \ref{sec:evaluation-efficiency}).
  • Figure 5: The training profiles of $\hbox{TabM}_\text{mini}^{k=32}$ and $\hbox{TabM}_\text{mini}^{k=1}$ as described in \ref{sec:analysis-optimization}. (Upper) The training curves. $k=32[i]$ represents the mean individual loss over the $32$ submodels. (Lower) Same as the first row, but in the train-test coordinates: each dot represents some epoch from the first row, and the training generally goes from left to right. This allows reasoning about overfitting by comparing test loss values for a given train loss value.
  • ...and 6 more figures
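
The adapter scheme in the Figure 1 caption can be sketched as a single linear layer. This is a minimal NumPy illustration, not the authors' implementation: it assumes the BatchEnsemble-style form $l(X) = ((X \odot R)\,W) \odot S + B$, where $W$ is shared by all $k$ submodels and $R$, $S$, $B$ hold one row per submodel. The function name, the shapes, and the initialization below are hypothetical (the caption notes that TabM's initialization differs from BatchEnsemble's, and it is not reproduced here).

```python
import numpy as np

def batch_ensemble_linear(x, weight, r, s, b):
    """BatchEnsemble-style linear layer (illustrative sketch).

    x:      (batch, k, d_in)  one input copy per submodel
    weight: (d_in, d_out)     shared by all k submodels
    r:      (k, d_in)         per-submodel input scaling (adapter R)
    s:      (k, d_out)        per-submodel output scaling (adapter S)
    b:      (k, d_out)        per-submodel bias (adapter B)
    """
    return ((x * r) @ weight) * s + b

rng = np.random.default_rng(0)
batch, k, d_in, d_out = 4, 3, 5, 2

x = rng.normal(size=(batch, d_in))
x_k = np.repeat(x[:, None, :], k, axis=1)  # feed the same object to all k submodels

weight = rng.normal(size=(d_in, d_out))
# Assumed initialization for illustration only: random signs for R, identity S, zero B.
r = rng.choice([-1.0, 1.0], size=(k, d_in))
s = np.ones((k, d_out))
b = np.zeros((k, d_out))

out = batch_ensemble_linear(x_k, weight, r, s, b)  # (batch, k, d_out): k predictions per object
prediction = out.mean(axis=1)                      # assumed aggregation: average over submodels

# Parameter count versus k independent layers of the same size (the "packed" variant).
shared_params = weight.size + r.size + s.size + b.size
packed_params = k * (weight.size + d_out)
print(out.shape, prediction.shape, shared_params, packed_params)
```

Submodel $i$ behaves like an ordinary linear layer with the effective weight $W \odot (r_i s_i^\top)$, which is why the layer can imitate $k$ distinct MLPs while storing only one full weight matrix. The parameter saving only becomes visible at realistic widths; at the toy sizes used here the adapters outnumber the shared weights.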