A Neural Network Alternative to Tree-based Models

Salvatore Raieli, Nathalie Jeanray, Stéphane Gerart, Sebastien Vachenc, Abdulrahman Altahhan

TL;DR

This work tackles the challenge of applying neural networks to tabular biological data, a domain where tree-based methods dominate thanks to their interpretability while neural networks struggle to match their performance. It introduces sTabNet, a sparse, attention-guided neural network that enforces a priori sparsity through a feature-wise adjacency mask and yields intrinsic feature importance, enabling end-to-end interpretability without post-hoc explanations. Across METABRIC, TCGA-BRCA/LUAD, single-cell, and survival analyses, sTabNet achieves performance competitive with or superior to XGBoost, while also supporting transfer learning and producing meaningful latent representations. The approach demonstrates robust in-domain and out-of-domain adaptation and suggests that sparse, attention-based tabular foundation models can be scalable and explainable for biomedical AI applications.
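
The feature-wise adjacency mask can be illustrated with a minimal NumPy sketch, assuming the formulation described for Figure 1B: a dense layer whose weight matrix W is multiplied element-wise (Hadamard product) by a fixed binary matrix A, where A[i, j] = 1 only if feature i belongs to the group represented by neuron j. The function name `sparse_layer`, the ReLU activation, and the toy 4-feature, 2-neuron mask are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def sparse_layer(x: np.ndarray, W: np.ndarray, A: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical sparse layer: a dense layer with weights masked by A.

    x: (n_samples, n_features); W, A: (n_features, n_neurons); b: (n_neurons,)
    A is binary and fixed before training (a priori sparsity).
    """
    if W.shape != A.shape:
        raise ValueError("W and A must have the same shape")
    masked_W = W * A  # Hadamard product: connections absent from A are zeroed
    return np.maximum(0.0, x @ masked_W + b)  # ReLU activation (an assumption)


# Toy mask: features 0-1 form group 0, features 2-3 form group 1.
A = np.array([[1, 0],
              [1, 0],
              [0, 1],
              [0, 1]], dtype=float)
rng = np.random.default_rng(0)
W = rng.normal(size=A.shape)
b = np.zeros(2)
x = rng.normal(size=(3, 4))
out = sparse_layer(x, W, A, b)
```

Because the mask enters the forward pass as W * A, the gradient with respect to any masked weight is also zero, so connections forbidden by A never influence the output, even during training with automatic differentiation.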

Abstract

Tabular datasets are widely used in scientific disciplines such as biology. Although these disciplines have already adopted AI methods to enhance their findings and analyses, they mainly use tree-based methods because of their interpretability. Artificial neural networks (ANNs), meanwhile, offer superior flexibility and depth for rich and complex non-tabular problems, but they lag behind tree-based models on tabular data in both performance and interpretability. Sparsity has been shown to improve the interpretability and performance of ANN models on complex non-tabular datasets, yet how to enforce sparsity structurally, before the model is trained, for tabular data remains an open question. To address this question, we establish a method that infuses sparsity into neural networks and utilises attention mechanisms to capture feature importance in tabular datasets. We show that our models, Sparse TABular NETs (sTabNet) with attention mechanisms, are more effective than tree-based models, reaching state-of-the-art performance on biological datasets. They further permit the extraction of insights from these datasets and estimate feature importance better than post-hoc methods such as SHAP.
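
How attention weights can double as intrinsic feature importance is sketched below. The paper's exact attention formulation is not given in this summary (Fig. S1 compares Bahdanau-, Luong-, and Graves-inspired variants), so this hypothetical example uses a simple per-feature score, tanh(x * w + b), normalised with a softmax across features. The names `feature_attention` and `global_importance`, and the choice to average the weights over samples as a global importance score, are assumptions.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def feature_attention(x: np.ndarray, w: np.ndarray, b: np.ndarray):
    """Hypothetical feature-wise attention (not the paper's exact formulation).

    x: (n_samples, n_features); w, b: (n_features,) learnable parameters.
    Returns the re-weighted inputs and the attention weights, one per feature
    per sample, which sum to 1 across features.
    """
    scores = np.tanh(x * w + b)      # one scalar score per feature and sample
    alpha = softmax(scores, axis=1)  # normalise across features
    return x * alpha, alpha


def global_importance(alpha: np.ndarray) -> np.ndarray:
    """Average attention weight each feature receives over all samples (assumed)."""
    return alpha.mean(axis=0)


# Usage on random data: 8 samples, 5 features, untrained parameters.
rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))
w = rng.normal(size=5)
b = np.zeros(5)
attended, alpha = feature_attention(x, w, b)
importance = global_importance(alpha)
ranking = np.argsort(importance)[::-1]  # feature indices, most important first
```

In a trained model, w and b would be learned jointly with the downstream sparse layers, so the ranking reflects what the network actually relies on rather than a post-hoc approximation.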

Paper Structure

This paper contains 32 sections, 2 equations, 13 figures, 3 tables, and 1 algorithm.

Figures (13)

  • Figure 1: A sparse and interpretable neural network. A. sTabNet architecture: features can be grouped according to prior knowledge or by unsupervised learning (clustering or random walks) to build a matrix A whose rows are features and whose columns are clusters (neurons). In this sparse model, a feature is connected to a neuron (which represents a cluster) only if it is a member of that cluster. B. sTabNet sparsity: the representation and definition of the classical dense layer (left) and of the proposed sparse layer (right). The sparse layer is identical to the dense layer except for the Hadamard product between the weight matrix W and the matrix A. C. sTabNet grouping: (left) the matrix A can be interpreted as a compressed view of the adjacency matrix of the feature graph. A neuron can also be defined as a random walk on the feature graph, thus learning a local approximation of a feature's neighborhood. Alternatively, clustering can be used (not shown in the figure). (right) Unrolling of the process on the left: when no information about the features in a dataset is available, we calculate the cosine similarity matrix of the features, assign an edge between two features if their similarity is higher than 0.5, perform random walks on the resulting graph, and use these random walks to build the sparse matrix in the modified layer. D. sTabNet as a tabular foundational model: a scheme of sTabNet used for different tasks and data types. The same architecture can be used for common and challenging biological tasks (binary/multiclass classification, censored regression) and complex data (RNA-seq, single-cell, and multi-omics data). sTabNet has been tested on real-world datasets for all these tasks. On the left, the figure shows that the model can be trained on one dataset and then reused for other datasets and tasks through fine-tuning or feature extraction.
  • Figure 2: Attention mechanisms are a measure of feature importance. A-E. Each boxplot represents 100-fold hold-out validation; a lower coefficient corresponds to a harder multiclass classification task. A. Multiclass classification accuracy of XGBoost as separation difficulty increases. B. Separation between the average importance weights (XGBoost feature importance) assigned to truly informative and non-informative features. C. Separation between the average importance weights (SHAP values) assigned to informative and non-informative features. D. Multiclass classification accuracy of sTabNet as separation difficulty increases. E. Separation between the average importance weights (feature attention weights) assigned to truly informative and non-informative features. F-G. Accuracy of each model (XGBoost and sTabNet); the shaded area represents the standard deviation (10 different models for each removed feature). F. MoRF (most relevant first) analysis. G. LeRF (least relevant first) analysis. H-I. The shaded area represents the standard deviation (10 models trained). H. XGBoost feature importance as the number of non-informative features increases. I. sTabNet feature importance as the number of non-informative features increases.
  • Figure 3: sTabNet provides a foundational model for in-domain and out-of-domain fine-tuning; it is interpretable and generally outperforms tree-based models. A. Comparison of XGBoost, sTabNet, and a CNN on the METABRIC multi-omics dataset (multiclass classification). B. In-domain (breast cancer) and out-of-domain (lung cancer) adaptation: the model was trained on the METABRIC dataset and then fine-tuned on other datasets. Accuracy on TCGA-BRCA (same domain as METABRIC) and TCGA-LUAD (different domain from the original dataset) for fine-tuning or feature extraction. C. Binary and multiclass classification accuracy of sTabNet and XGBoost on single-cell data (breast cancer, GSE161529). D. Concordance index for METABRIC survival analysis. E. Disease enrichment for the top 100 genes ranked by attention importance. F. Disease enrichment for the top 100 genes ranked by attention importance from METABRIC.
  • Figure 4: Generality of the proposed sTabNet architecture. sTabNets are competitive with tree-based models on tabular data. A. The random-walk algorithm used to adapt the model to any tabular dataset. We calculated the similarity between the dataset's features (cosine similarity matrix) and used the resulting matrix to generate a feature graph. We conducted random walks on the feature graph to explore the local neighbourhood of each feature and used that knowledge to build the sparse matrix of the neural network. B. Binary classification accuracy on the tabular benchmark (100 models per technique for 3 techniques: 10 runs, each with a random training/testing split, for each of 10 datasets). C. Feature presence in the top random walks (100 experiments on the pol dataset, each with a random training/testing split). D. Feature graph for the pol dataset (isolated nodes removed); the nodes highlighted in red are the top 5 features present in the top random walks. E-G. Ablation study: accuracy (E), false positives (F), and false negatives (G) for three datasets when removing 5 random features vs. removing the top 5 features identified by the random-walk process (10 experiments per dataset).
  • Figure 5 (Fig. S1): Attention mechanisms are a measure of feature importance. A-C. Each boxplot represents 100-fold hold-out validation; a lower coefficient corresponds to a harder multiclass classification task. The upper plot shows the multiclass classification accuracy of the sparse network as separation difficulty increases. The lower plot shows the separation between the average importance weights (feature attention weights) assigned to truly informative and non-informative features. A. Bahdanau-inspired attention mechanism. B. Luong-inspired attention mechanism. C. Graves-inspired attention mechanism.
  • ...and 8 more figures
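
The random-walk grouping used when no prior knowledge about the features is available (Figure 1C, Figure 4A) might be sketched as follows. The cosine similarity between features and the 0.5 edge threshold come from the Figure 1 caption; the walk length, the number of walks per feature, uniform neighbour sampling, and the choice of one neuron (one column of A) per walk are assumptions, and `build_mask_from_random_walks` is a hypothetical helper rather than the authors' code.

```python
import numpy as np


def cosine_similarity_features(X: np.ndarray) -> np.ndarray:
    """Cosine similarity between the columns (features) of X (n_samples, n_features)."""
    norms = np.linalg.norm(X, axis=0)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero features
    Xn = X / norms
    return Xn.T @ Xn


def build_mask_from_random_walks(X, threshold=0.5, walk_length=5,
                                 walks_per_node=1, seed=0):
    """Hypothetical construction of the sparse matrix A from random walks."""
    rng = np.random.default_rng(seed)
    n_features = X.shape[1]
    adj = cosine_similarity_features(X) > threshold  # edge if similarity > threshold
    np.fill_diagonal(adj, False)                      # no self-loops
    walks = []
    for start in range(n_features):
        for _ in range(walks_per_node):
            walk = [start]
            for _ in range(walk_length - 1):
                neighbours = np.flatnonzero(adj[walk[-1]])
                if neighbours.size == 0:
                    break  # isolated feature: the walk stays at its start node
                walk.append(int(rng.choice(neighbours)))  # uniform neighbour choice
            walks.append(walk)
    # Rows are features, columns are neurons (one per walk).
    A = np.zeros((n_features, len(walks)))
    for j, walk in enumerate(walks):
        A[walk, j] = 1.0  # feature i feeds neuron j if walk j visited it
    return A


# Toy data: features 0 and 1 are identical; features 2 and 3 are unrelated.
X = np.array([[1.0, 1.0, 0.0, 0.0],
              [2.0, 2.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])
A = build_mask_from_random_walks(X, threshold=0.5, walk_length=5)
```

In this toy case, features 0 and 1 end up connected to the same two neurons, while the isolated features 2 and 3 each get a neuron of their own; the resulting A could then serve as the mask in a sparse layer.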