
TabR: Tabular Deep Learning Meets Nearest Neighbors in 2023

Yury Gorishniy, Ivan Rubachev, Nikolay Kartashev, Daniil Shlenskii, Akim Kotelnikov, Artem Babenko

TL;DR

TabR addresses the persistent gap in which gradient-boosted decision trees (GBDT) outperform deep learning on tabular data by introducing a retrieval-augmented deep learning approach. It inserts a lightweight, single-head retrieval module into a feed-forward backbone. The module scores neighbors by their L2 distance in key space and uses a correction-based value module to extract signal from the nearest neighbors in the training set. Across public benchmarks, TabR achieves the best average performance among tabular DL models, sets new state-of-the-art results on several datasets, and even outperforms GBDT on a benchmark of middle-scale datasets, while being significantly more efficient than prior retrieval-based models. The work also explores training speedups via context freezing and online updates, highlighting the practical viability and future potential of retrieval-augmented tabular DL, including its interpretability and continual-learning aspects.
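The retrieval step summarized above can be sketched in NumPy for a single target object. This is a minimal illustration under stated assumptions, not the authors' implementation: the matrices `W_K`, `W_Y`, and `W_T` are hypothetical parameters, labels are assumed to be integer class ids embedded through a lookup table, the correction term is reduced to a single linear map of the key difference, and the aggregated value is assumed to be added residually to the target representation.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def retrieval_module(x, cand_x, cand_y, W_K, W_Y, W_T, m):
    """Hypothetical TabR-style retrieval step for one target object.

    x:      (d,) encoded target representation (output of the encoder E)
    cand_x: (n, d) encoded candidates from the training set
    cand_y: (n,) integer class labels of the candidates
    W_K:    (d, d) hypothetical key projection
    W_Y:    (num_classes, d) hypothetical label-embedding table
    W_T:    (d, d) hypothetical linear stand-in for the correction term
    m:      number of nearest neighbors to keep
    """
    k = x @ W_K                                  # target key, shape (d,)
    cand_k = cand_x @ W_K                        # candidate keys, shape (n, d)
    # Similarity = negative squared L2 distance in key space.
    sim = -np.sum((cand_k - k) ** 2, axis=1)     # shape (n,)
    top = np.argsort(-sim)[:m]                   # indices of the m nearest candidates
    w = softmax(sim[top])                        # attention-like weights, sum to 1
    # Correction-based value: label embedding plus a term of the key difference.
    values = W_Y[cand_y[top]] + (k - cand_k[top]) @ W_T   # shape (m, d)
    # Assumption: the aggregated value is added residually to the target.
    return x + w @ values


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, n, num_classes = 4, 10, 3
    x = rng.normal(size=d)
    cand_x = rng.normal(size=(n, d))
    cand_y = rng.integers(0, num_classes, size=n)
    out = retrieval_module(
        x, cand_x, cand_y,
        W_K=np.eye(d),
        W_Y=rng.normal(size=(num_classes, d)),
        W_T=0.1 * rng.normal(size=(d, d)),
        m=3,
    )
    print(out.shape)  # (4,)
```

Keeping only the top-$m$ candidates before the softmax is what makes the module kNN-like rather than full attention over the whole training set.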

Abstract

Deep learning (DL) models for tabular data problems (e.g. classification, regression) are receiving increasing attention from researchers. However, despite recent efforts, non-DL algorithms based on gradient-boosted decision trees (GBDT) remain a strong go-to solution for these problems. One research direction aimed at improving the position of tabular DL involves designing so-called retrieval-augmented models. For a target object, such models retrieve other objects (e.g. the nearest neighbors) from the available training data and use their features and labels to make a better prediction. In this work, we present TabR: essentially, a feed-forward network with a custom k-Nearest-Neighbors-like component in the middle. On a set of public benchmarks with datasets of up to several million objects, TabR marks a big step forward for tabular DL: it demonstrates the best average performance among tabular DL models, becomes the new state of the art on several datasets, and even outperforms GBDT models on the recently proposed "GBDT-friendly" benchmark (see Figure 1). Among the important findings and technical details powering TabR, the main ones concern the attention-like mechanism that retrieves the nearest neighbors and extracts valuable signal from them. In addition to much higher performance, TabR is simple and significantly more efficient than prior retrieval-based tabular DL models.

Paper Structure (49 sections, 8 equations, 9 figures, 24 tables)

Figures (9)

  • Figure 1: Comparing DL models with XGBoost chen2016xgboost on 43 regression and classification tasks of middle scale ($\le 50K$ objects) from "Why do tree-based models still outperform deep learning on typical tabular data?" by grinsztajn2022why. TabR marks a significant step forward compared to prior tabular DL models and continues the positive trend for the field.
  • Figure 2: The generic retrieval-based architecture introduced in the model architecture section and used to build TabR. First, a target object and its candidates for retrieval are encoded with the same encoder $E$. Then, the retrieval module $R$ enriches the target object's representation by retrieving and processing relevant objects from the candidates. Finally, the predictor $P$ makes a prediction. The bold path highlights the structure of the feed-forward retrieval-free model before the addition of the retrieval module $R$.
  • Figure 3: Encoder $E$ and predictor $P$ introduced in Figure 2. $N_E$ and $N_P$ denote the number of Block modules in $E$ and $P$, respectively. The Input Module encapsulates the input processing routines (feature normalization, one-hot encoding, etc.) and assembles a vector input for the subsequent linear layer. In particular, the Input Module can contain embeddings for continuous features gorishniy2022embeddings. ($^*$LayerNorm is omitted in the first Block of $E$.)
  • Figure 4: Simplified illustration of the retrieval module $R$ introduced in Figure 2 (the omitted details are provided in the main text). For the target object's representation $\tilde{x}$, the module takes the $m$ nearest neighbors among the candidates $\{\tilde{x}_i\}$ according to the similarity module $S:(\mathbb{R}^d,\mathbb{R}^d) \rightarrow \mathbb{R}$ and aggregates their values produced by the value module $\mathcal{V}:(\mathbb{R}^d,\mathbb{R}^d,\mathbb{Y}) \rightarrow \mathbb{R}^d$.
  • Figure 5: $\Delta$-context (explained below) averaged over training objects until early stopping while training TabR-S. On a given epoch, for a given object, $\Delta$-context shows the portion of its context (the top-$m$ candidates and their weights) that changed compared to the previous epoch (i.e., the lower the value, the smaller the change; see the appendix for formal details). The plot shows that context updates become less intensive over the course of training, which motivates the context-freezing optimization described in the main text.
  • ...and 4 more figures
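Figure 5 relies on $\Delta$-context, whose formal definition is given in the paper's appendix and is not reproduced here. The sketch below assumes one plausible reading of "the portion of its context that changed": the total variation distance between the previous and current weight distributions over the top-$m$ candidates, averaged over training objects. The helper names `delta_context` and `mean_delta_context` are hypothetical.

```python
def delta_context(prev_idx, prev_w, curr_idx, curr_w):
    """Portion of one object's context that changed between two epochs.

    A context is the list of top-m candidate indices and their weights
    (assumed to sum to 1). Assumption: the change is measured as the total
    variation distance 0.5 * sum_i |w_prev(i) - w_curr(i)| over the union
    of both candidate sets, so 0 means unchanged and 1 means fully replaced.
    """
    prev = dict(zip(prev_idx, prev_w))
    curr = dict(zip(curr_idx, curr_w))
    union = set(prev) | set(curr)
    return 0.5 * sum(abs(prev.get(i, 0.0) - curr.get(i, 0.0)) for i in union)


def mean_delta_context(prev_contexts, curr_contexts):
    """Average Delta-context over training objects; contexts are (idx, w) pairs."""
    vals = [
        delta_context(p_idx, p_w, c_idx, c_w)
        for (p_idx, p_w), (c_idx, c_w) in zip(prev_contexts, curr_contexts)
    ]
    return sum(vals) / len(vals)


if __name__ == "__main__":
    # One of two equally weighted neighbors is swapped out: half the mass moved.
    print(delta_context([0, 1], [0.5, 0.5], [1, 2], [0.5, 0.5]))  # 0.5
```

Under this reading, a value near 0 means the retrieved neighbors and their weights have essentially stopped changing, which is the regime in which freezing the context becomes attractive.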