Revisiting Nearest Neighbor for Tabular Data: A Deep Tabular Baseline Two Decades Later

Han-Jia Ye, Huai-Hong Yin, De-Chuan Zhan, Wei-Lun Chao

TL;DR

This work revisits Neighbourhood Components Analysis (NCA) as a foundation for tabular learning and progressively augments it with modern deep-learning techniques. Two main advances lead to ModernNCA, a strong deep tabular baseline: L-NCA (an improved linear NCA allowing higher-dimensional, learnable embeddings trained with a soft-NN loss) and M-NCA (a nonlinear, deep embedding with Stochastic Neighborhood Sampling (SNS) and PLR encoding). On a large-scale benchmark of 300 tabular datasets, ModernNCA ranks highly in classification and approaches CatBoost in regression, while training faster than other deep tabular models and using relatively little memory. The study also provides extensive ablations that reveal which components (architecture, loss, encoding, and sampling) contribute most to performance, and discusses limitations and potential directions for future work.
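The soft-NN loss combined with Stochastic Neighborhood Sampling can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the embeddings are taken as already computed (in ModernNCA they would come from the deep embedding network), distances are plain Euclidean distances scaled by a hypothetical `temperature`, and the neighbor candidates for each training step are a uniform random fraction `rate` of the training set, with each query instance excluded from its own candidates. The helper names `soft_nn_class_probs` and `sns_loss` are hypothetical.

```python
import math
import random

def soft_nn_class_probs(query, cand_emb, cand_labels, num_classes, temperature=1.0):
    # Each candidate votes for its label with weight proportional to exp(-dist / temperature)
    logits = [-math.dist(query, c) / temperature for c in cand_emb]
    m = max(logits)  # subtract the max before exp() for numerical stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    probs = [0.0] * num_classes
    for w, y in zip(weights, cand_labels):
        probs[y] += w / total
    return probs

def sns_loss(batch_idx, emb, labels, num_classes, rate, rng):
    # Assumed SNS: sample a fraction `rate` of the training set as neighbor candidates
    n = len(emb)
    k = min(n, max(2, int(rate * n)))  # keep at least one candidate after excluding the query
    sampled = rng.sample(range(n), k)
    total = 0.0
    for i in batch_idx:
        cands = [j for j in sampled if j != i]  # an instance is never its own neighbor
        probs = soft_nn_class_probs(
            emb[i], [emb[j] for j in cands], [labels[j] for j in cands], num_classes
        )
        # Soft-NN (NCA-style) loss: negative log of the probability mass on the true class
        total += -math.log(max(probs[labels[i]], 1e-12))
    return total / len(batch_idx)
```

Lowering `rate` shrinks the candidate set per step, which makes each step cheaper and adds training stochasticity; setting `rate=1.0` recovers the loss over the full training set.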

Abstract

The widespread enthusiasm for deep learning has recently expanded into the domain of tabular data. Recognizing that the advancement in deep tabular methods is often inspired by classical methods, e.g., integration of nearest neighbors into neural networks, we investigate whether these classical methods can be revitalized with modern techniques. We revisit a differentiable version of $K$-nearest neighbors (KNN), Neighbourhood Components Analysis (NCA), which was originally designed to learn a linear projection that captures semantic similarities between instances, and seek to gradually add modern deep learning techniques on top. Surprisingly, our implementation of NCA using SGD and without dimensionality reduction already achieves decent performance on tabular data, in contrast to the results of using existing toolboxes like scikit-learn. Further equipping NCA with deep representations and additional training stochasticity significantly enhances its capability, putting it on par with the leading tree-based method CatBoost and ahead of existing deep tabular models in both classification and regression tasks on 300 datasets. We conclude our paper by analyzing the factors behind these improvements, including loss functions, prediction strategies, and deep architectures. The code is available at https://github.com/qile2000/LAMDA-TALENT.
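To make the "differentiable KNN" idea concrete, here is a sketch of soft-NN prediction on top of a linear NCA-style embedding $f(x) = Ax$ for a regression target. Each training instance contributes its target with a weight proportional to $\exp(-\lVert f(x) - f(x_j)\rVert / \tau)$, so the prediction is a smooth function of $A$ that gradient descent can optimize. This is a minimal plain-Python illustration, assuming Euclidean distance and a temperature $\tau$ (`temperature`); the function names are hypothetical. The original NCA uses squared Euclidean distance instead, and here `A` has more rows than the input has features, reflecting the abstract's note that NCA without dimensionality reduction already performs decently.

```python
import math

def project(A, x):
    # Linear embedding f(x) = A x; A is a list of rows, each of length len(x)
    return [sum(a * xi for a, xi in zip(row, x)) for row in A]

def soft_nn_regress(query, train_x, train_y, A, temperature=1.0):
    # Embed the query and every training instance with the same projection
    q = project(A, query)
    dists = [math.dist(q, project(A, x)) for x in train_x]
    # Softmax over negative distances; subtracting the max keeps exp() stable
    logits = [-d / temperature for d in dists]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    # The prediction is the weight-averaged target of the neighbors
    return sum(w * y for w, y in zip(weights, train_y)) / total

# A 3x2 projection maps 2-D inputs into a higher-dimensional (3-D) space
A = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
train_x = [(0.0, 0.0), (10.0, 10.0)]
train_y = [1.0, 3.0]
print(soft_nn_regress((5.0, 5.0), train_x, train_y, A))  # prints 2.0 (equidistant query)
```

A query close to one training instance inherits essentially that instance's target, while an equidistant query gets the plain average; for classification the same weights would be summed per class instead of averaged over targets.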

Paper Structure (28 sections, 7 equations, 4 figures, 9 tables)

Figures (4)

  • Figure 1: Performance-efficiency-memory comparison between ModernNCA and existing methods on classification (a) and regression (b) datasets. Representative tabular prediction methods are investigated, including classical methods (in green), parametric deep methods (in blue), and non-parametric/neighborhood-based deep methods (in red), based on their results over 300 datasets reported in the main results table and figure. The average rank among these eight methods is used as the performance measure. We also report the average training time (in seconds) and the memory usage of each model (shown by the circle radius: the larger the circle, the bigger the model). ModernNCA trains faster than other deep tabular models and has relatively low memory usage. L-NCA is our improved linear version of NCA.
  • Figure 2: Critical difference diagrams based on the Wilcoxon-Holm test at a significance level of 0.05, used to detect pairwise significant differences for both classification tasks (evaluated by accuracy) and regression tasks (evaluated by RMSE).
  • Figure 3: Change in average performance rank under different sampling rates ({10%, 30%, 50%, 80%, 100%}) in the SNS strategy. The dotted line denotes the rank of ModernNCA.
  • Figure 4: Visualization of the embedding space of different methods.
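The average rank used as the performance measure in Figures 1 and 3 can be computed as sketched below. This is a minimal illustration with toy scores, not the paper's results; it assumes methods are ranked separately on each dataset (1 = best) and that tied methods share the mean of their positions, a common convention that the captions do not specify. `average_ranks` is a hypothetical helper name.

```python
def average_ranks(scores, higher_is_better=True):
    # scores: {dataset: {method: metric}}; use higher_is_better=False for RMSE
    methods = sorted(next(iter(scores.values())).keys())
    totals = {m: 0.0 for m in methods}
    for per_method in scores.values():
        ordered = sorted(methods, key=lambda m: per_method[m], reverse=higher_is_better)
        i = 0
        while i < len(ordered):
            # Extend j over a run of tied scores starting at position i
            j = i
            while j + 1 < len(ordered) and per_method[ordered[j + 1]] == per_method[ordered[i]]:
                j += 1
            shared = (i + j) / 2 + 1  # 1-based ranks; ties share the mean rank
            for k in range(i, j + 1):
                totals[ordered[k]] += shared
            i = j + 1
    # Average each method's rank over all datasets
    return {m: totals[m] / len(scores) for m in methods}

# Toy accuracies for three hypothetical methods on two datasets
toy = {
    "d1": {"A": 0.9, "B": 0.8, "C": 0.8},
    "d2": {"A": 0.7, "B": 0.9, "C": 0.6},
}
print(average_ranks(toy))  # {'A': 1.5, 'B': 1.75, 'C': 2.75}
```

Because ranks are computed per dataset before averaging, a method that wins narrowly everywhere can outrank one that wins by large margins on only a few datasets; this is why rank-based summaries are usually paired with significance tests such as the one in Figure 2.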