When Do Neural Nets Outperform Boosted Trees on Tabular Data?

Duncan McElfresh, Sujay Khandagale, Jonathan Valverde, Vishak Prasad C, Benjamin Feuer, Chinmay Hegde, Ganesh Ramakrishnan, Micah Goldblum, Colin White

TL;DR

This work questions how much the debate over whether neural nets (NNs) or gradient-boosted decision trees (GBDTs) perform better on tabular data actually matters. It does so through the largest tabular-data study to date: 19 algorithms evaluated across 176 OpenML datasets, with extensive hyperparameter tuning and metafeature analyses. The results show no universal winner. On many datasets the NN-GBDT gap is negligible, or light tuning of a strong GBDT matters more than the choice of algorithm family. TabPFN performs exceptionally well on small datasets, while GBDTs excel on irregular or large datasets. The authors also introduce TabZilla, a benchmark suite of the 36 hardest datasets in the study, released with open-source code and raw results to accelerate future research and to support choosing methods based on dataset properties. The findings give practitioners concrete guidelines and give researchers a resource for systematically comparing tabular methods and studying their failure modes.

Abstract

Tabular data is one of the most commonly used types of data in machine learning. Despite recent advances in neural nets (NNs) for tabular data, there is still an active discussion on whether or not NNs generally outperform gradient-boosted decision trees (GBDTs) on tabular data, with several recent works arguing either that GBDTs consistently outperform NNs on tabular data, or vice versa. In this work, we take a step back and question the importance of this debate. To this end, we conduct the largest tabular data analysis to date, comparing 19 algorithms across 176 datasets, and we find that the 'NN vs. GBDT' debate is overemphasized: for a surprisingly high number of datasets, either the performance difference between GBDTs and NNs is negligible, or light hyperparameter tuning on a GBDT is more important than choosing between NNs and GBDTs. A remarkable exception is the recently-proposed prior-data fitted network, TabPFN: although it is effectively limited to training sets of size 3000, we find that it outperforms all other algorithms on average, even when randomly sampling 3000 training datapoints. Next, we analyze dozens of metafeatures to determine what properties of a dataset make NNs or GBDTs better-suited to perform well. For example, we find that GBDTs are much better than NNs at handling skewed or heavy-tailed feature distributions and other forms of dataset irregularities. Our insights act as a guide for practitioners to determine which techniques may work best on their dataset. Finally, with the goal of accelerating tabular data research, we release the TabZilla Benchmark Suite: a collection of the 36 'hardest' of the datasets we study. Our benchmark suite, codebase, and all raw results are available at https://github.com/naszilla/tabzilla.
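The abstract notes that TabPFN is effectively limited to 3000 training points and is evaluated by randomly sampling 3000 training datapoints on larger datasets. A minimal sketch of that subsampling step, assuming uniform sampling without replacement (the abstract does not say whether sampling is stratified). The function name `subsample_training_set` and its parameters are hypothetical, and fitting TabPFN itself is left out rather than guessing at its API.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

Row = TypeVar("Row")
Label = TypeVar("Label")


def subsample_training_set(
    X: Sequence[Row],
    y: Sequence[Label],
    max_rows: int = 3000,  # cap taken from the abstract
    seed: int = 0,
) -> Tuple[List[Row], List[Label]]:
    """Hypothetical helper: keep at most `max_rows` rows, sampled uniformly
    without replacement, so features and labels stay aligned."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    if len(X) <= max_rows:
        # Small training sets are used in full.
        return list(X), list(y)
    rng = random.Random(seed)  # seeded for reproducible splits
    # Sorting the sampled indices preserves the original row order.
    idx = sorted(rng.sample(range(len(X)), max_rows))
    return [X[i] for i in idx], [y[i] for i in idx]
```

Only the training split would be subsampled in this sketch; the test split stays untouched so the comparison with other algorithms is made on the same test data.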

Paper Structure (49 sections, 15 figures, 23 tables)

Figures (15)

  • Figure 1: Overview of our work. We start by conducting the largest study on tabular data to date (left); we analyze the importance of algorithm selection ('NNs vs. GBDTs') and of dataset metafeatures (middle); and based on our study, we release TabZilla, a collection of the hardest tabular datasets (right).
  • Figure 2: Median runtime vs. median normalized accuracy for each algorithm, over 98 datasets. The bars span the 20th to 80th percentile over all datasets.
  • Figure 3: Critical difference plot comparing all algorithms according to their mean log loss rank over 98 datasets. Each algorithm's average rank is shown as a horizontal line on the axis. Sets of algorithms which are not significantly different are connected by a horizontal black bar. Algorithm family is indicated by a marker next to the algorithm name: red "X" indicates a neural net, blue circle indicates a baseline algorithm, green triangles indicate GBDTs, and purple squares indicate a PFN.
  • Figure 4: Left: Venn diagram of the number of datasets on which each algorithm class is 'high-performing', over all 176 datasets. An algorithm is high-performing if its test accuracy after 0-1 scaling is at least 0.99 (results with a 0.9999 threshold are shown in the appendix). Right: the performance improvement from hyperparameter tuning on CatBoost, compared with the absolute performance difference between the best neural net and the best GBDT using default hyperparameters. Each point shows the normalized log loss of one dataset. Points on or below the dotted line indicate that the improvement from tuning is at least as large as the difference made by choosing between an NN and a GBDT.
  • Figure 5: Left: scatterplot of the best algorithm on each of the 176 datasets, positioned by metafeatures. The vertical axis indicates dataset size, and the horizontal axis combines five dataset metafeatures related to irregularity. Right: scatterplots of the difference in normalized log loss between XGBoost and ResNet, by dataset size (middle subplot) and by irregularity (right subplot). The irregularity feature is a linear combination of five standardized dataset attributes, with weights in parentheses: the minimum eigenvalue of the feature covariance matrix (-0.33), the skewness of the standard deviation of all features (0.23), the skewness of the range of all features (0.22), the interquartile range of the harmonic mean of all features (0.21), and the standard deviation of the kurtosis of all features (0.21).
  • ...and 10 more figures
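The 0-1 scaling and the 'high-performing' test from the Figure 4 caption can be sketched as below. This assumes per-dataset min-max scaling of each algorithm's score across all algorithms (the worst maps to 0, the best to 1); for a lower-is-better metric such as log loss, negate the values first. Two further assumptions: an algorithm class counts as high-performing on a dataset when any of its members clears the threshold, and a dataset where every algorithm ties gives every algorithm a scaled score of 1. The algorithm-to-class mapping and all numbers in the test are made up for illustration.

```python
from collections import Counter
from typing import Dict, FrozenSet, Set


def minmax_scale(scores: Dict[str, float]) -> Dict[str, float]:
    """Scale one dataset's per-algorithm scores (higher is better) to [0, 1]."""
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        # Assumption: when all algorithms tie, every one counts as best.
        return {alg: 1.0 for alg in scores}
    return {alg: (s - lo) / (hi - lo) for alg, s in scores.items()}


def high_performing_classes(
    scores: Dict[str, float], alg_class: Dict[str, str], threshold: float = 0.99
) -> Set[str]:
    """Classes with at least one algorithm whose scaled score clears the threshold."""
    scaled = minmax_scale(scores)
    return {alg_class[alg] for alg, s in scaled.items() if s >= threshold}


def venn_counts(
    per_dataset_scores: Dict[str, Dict[str, float]],
    alg_class: Dict[str, str],
    threshold: float = 0.99,
) -> Counter:
    """Count datasets per Venn region, i.e. per exact set of high-performing classes."""
    counts: Counter = Counter()
    for scores in per_dataset_scores.values():
        region: FrozenSet[str] = frozenset(
            high_performing_classes(scores, alg_class, threshold)
        )
        counts[region] += 1
    return counts
```

Keying each dataset by the exact set of classes that clear the threshold maps directly onto the regions of a Venn diagram, so the counts can be drawn without further bookkeeping.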
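The irregularity axis described in the Figure 5 caption is a weighted sum of standardized metafeatures. A minimal sketch of that combination step, assuming the five raw metafeatures have already been computed for each dataset and that "standardized" means z-scored across the dataset collection using the population standard deviation. The weights are copied from the caption; the dictionary keys are hypothetical names for the five attributes it lists.

```python
import statistics
from typing import Dict

# Weights from the Figure 5 caption; the key names are hypothetical.
IRREGULARITY_WEIGHTS: Dict[str, float] = {
    "min_eigenvalue_feature_cov": -0.33,
    "skew_of_feature_std": 0.23,
    "skew_of_feature_range": 0.22,
    "iqr_of_feature_harmonic_mean": 0.21,
    "std_of_feature_kurtosis": 0.21,
}


def irregularity_scores(metafeatures: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Map dataset name -> weighted sum of z-scored metafeatures.

    `metafeatures[dataset][key]` holds the raw value of each metafeature.
    """
    names = list(metafeatures)
    scores = {name: 0.0 for name in names}
    for key, weight in IRREGULARITY_WEIGHTS.items():
        values = [metafeatures[n][key] for n in names]
        mu = statistics.mean(values)
        sigma = statistics.pstdev(values)  # population std (an assumption)
        for n, v in zip(names, values):
            # A constant metafeature carries no information, so its z-score is 0.
            z = (v - mu) / sigma if sigma > 0 else 0.0
            scores[n] += weight * z
    return scores
```

Under this convention a positive score means a dataset is more irregular than the average dataset in the collection; a small minimum covariance eigenvalue (nearly collinear features) pushes the score up because its weight is negative. The caption does not say how the weights were derived, so the sketch reproduces only the final combination.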
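The critical difference plot in Figure 3 orders algorithms by their mean log loss rank over 98 datasets. The ranking step can be sketched as follows, assuming rank 1 goes to the lowest log loss on each dataset, tied algorithms share the average of their positions, and every algorithm has a result on every dataset. The statistical test behind the black bars that connect non-significantly different algorithms is not named in the caption and is not reproduced here.

```python
from typing import Dict, List


def average_ranks(losses: Dict[str, float]) -> Dict[str, float]:
    """Rank algorithms on one dataset by log loss (1 = lowest); ties share the mean rank."""
    ordered = sorted(losses.items(), key=lambda kv: kv[1])
    ranks: Dict[str, float] = {}
    i = 0
    while i < len(ordered):
        # Extend j over the run of algorithms tied with position i.
        j = i
        while j + 1 < len(ordered) and ordered[j + 1][1] == ordered[i][1]:
            j += 1
        shared = (i + j) / 2 + 1  # mean of 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[ordered[k][0]] = shared
        i = j + 1
    return ranks


def mean_ranks(per_dataset_losses: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each algorithm's per-dataset rank over all datasets."""
    totals: Dict[str, float] = {}
    for losses in per_dataset_losses:
        for alg, r in average_ranks(losses).items():
            totals[alg] = totals.get(alg, 0.0) + r
    n = len(per_dataset_losses)
    return {alg: t / n for alg, t in totals.items()}
```

Averaging ranks instead of raw log losses keeps datasets with large absolute losses from dominating the comparison, which is why rank-based summaries are common in multi-dataset benchmarks.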