NCART: Neural Classification and Regression Tree for Tabular Data

Jiaqi Luo, Shixin Xu

TL;DR

NCART introduces an interpretable neural network for tabular data by integrating differentiable oblivious decision trees into a ResNet-like architecture, replacing fully-connected layers to retain interpretability while leveraging end-to-end learning. The model employs a differentiable ODT, sparse feature selection via a learnable sparse projection, and ensemble aggregation to form $O(\boldsymbol{x}) = \frac{1}{N}\sum_{i=1}^{N} w_i t_i(\boldsymbol{x}_i^P)$, organized in stacked NCART blocks with residual-style connections. Feature importance is computed similarly to traditional trees, aggregating impurity measures across all trees and layers, enabling interpretable insights. Extensive experiments on 20 OpenML datasets show NCART achieving competitive performance with state-of-the-art deep learning and tree-based methods, often excelling in F1-score and maintaining favorable inference efficiency, while highlighting remaining limitations in regression, missing-value handling, and hyperparameter sensitivity. Overall, NCART offers a practical, interpretable alternative that narrows the gap between interpretability of trees and the predictive power of neural networks for tabular data.
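
To make the ensemble aggregation in the summary above more concrete, the following is a minimal NumPy sketch of a weighted mean over soft oblivious trees, each fed its own sparsely projected input. It is an illustration under stated assumptions, not the authors' implementation: the function names (`leaf_probabilities`, `soft_oblivious_tree`, `ensemble_output`) are hypothetical, the sigmoid split gate is one common way to make an oblivious tree differentiable and may differ from the paper's exact formulation, and a fixed random sparse mask stands in for the learnable sparse projection. All parameters are random rather than trained.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def leaf_probabilities(xp, thresholds):
    """Soft routing probabilities for one oblivious tree.

    xp:         (batch, depth) projected inputs, one column per tree level
    thresholds: (depth,) one shared threshold per level (oblivious split)
    returns:    (batch, 2**depth) leaf probabilities, each row sums to 1
    """
    go_right = sigmoid(xp - thresholds)  # assumed soft split gate
    probs = np.ones((xp.shape[0], 1))
    for level in range(xp.shape[1]):
        r = go_right[:, level:level + 1]
        # Every current leaf splits into a left child and a right child.
        probs = np.concatenate([probs * (1.0 - r), probs * r], axis=1)
    return probs


def soft_oblivious_tree(xp, thresholds, leaf_values):
    """t_i(x_i^P): expected leaf value under the soft routing."""
    return leaf_probabilities(xp, thresholds) @ leaf_values


def ensemble_output(x, projections, thresholds, leaf_values, weights):
    """O(x) = (1/N) * sum_i w_i * t_i(x_i^P), with x_i^P = x @ projections[i]."""
    n_trees = len(weights)
    total = np.zeros(x.shape[0])
    for i in range(n_trees):
        xp = x @ projections[i]  # stand-in for the learned sparse projection
        total += weights[i] * soft_oblivious_tree(xp, thresholds[i], leaf_values[i])
    return total / n_trees


# Toy setup with random, untrained parameters (illustrative only).
rng = np.random.default_rng(0)
batch, n_features, depth, n_trees = 5, 8, 3, 4
x = rng.normal(size=(batch, n_features))
projections = [
    # Zero out roughly 70% of the entries to mimic sparsity.
    rng.normal(size=(n_features, depth)) * (rng.random((n_features, depth)) < 0.3)
    for _ in range(n_trees)
]
thresholds = [rng.normal(size=depth) for _ in range(n_trees)]
leaf_values = [rng.normal(size=2 ** depth) for _ in range(n_trees)]
weights = np.ones(n_trees)

out = ensemble_output(x, projections, thresholds, leaf_values, weights)
print(out.shape)
```

Because every operation here is smooth in the thresholds, leaf values, projections, and weights, a layer of this form could in principle be trained end-to-end by gradient descent, which is the property the summary attributes to the differentiable ODT.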

Abstract

Deep learning models have become popular in the analysis of tabular data, as they address the limitations of decision trees and enable valuable applications like semi-supervised learning, online learning, and transfer learning. However, these deep-learning approaches often encounter a trade-off. On one hand, they can be computationally expensive when dealing with large-scale or high-dimensional datasets. On the other hand, they may lack interpretability and may not be suitable for small-scale datasets. In this study, we propose a novel interpretable neural network called Neural Classification and Regression Tree (NCART) to overcome these challenges. NCART is a modified version of Residual Networks that replaces fully-connected layers with multiple differentiable oblivious decision trees. By integrating decision trees into the architecture, NCART maintains its interpretability while benefiting from the end-to-end capabilities of neural networks. The simplicity of the NCART architecture makes it well-suited for datasets of varying sizes and reduces computational costs compared to state-of-the-art deep learning models. Extensive numerical experiments demonstrate the superior performance of NCART compared to existing deep learning models, establishing it as a strong competitor to tree-based models.

Paper Structure (30 sections, 15 equations, 9 figures, 6 tables)

This paper contains 30 sections, 15 equations, 9 figures, 6 tables.

Figures (9)

  • Figure 1: Illustration of approximation of
  • Figure 2: Structure of an NCART block with 4 components: Batch Normalization (blue block) for data preprocessing, Feature Selection (red block, Eq. \ref{e.feat_sel}), Differentiable Oblivious Tree (green block, Eq. \ref{e.ndt}), and Weight Mean for ensemble (yellow block, Eq. \ref{e.ncart}).
  • Figure 3: NCART Network architecture. A green-filled block indicates the utilization of all features for feature splitting, while a red-filled block signifies the presence of a feature selection layer.
  • Figure 4: Mean$\pm$std. results of 11 models on different classification datasets. Bold indicates the top result; OOM denotes a GPU out-of-memory failure.
  • Figure 5: Rank values of different models on 20 datasets. Panels (a) and (b) show classification tasks; panel (c) shows regression tasks.
  • ...and 4 more figures