Beyond Black-Box Predictions: Identifying Marginal Feature Effects in Tabular Transformer Networks

Anton Thielmann, Arik Reuter, Benjamin Saefken

TL;DR

The paper tackles the tension between predictive accuracy and interpretability in tabular data by proposing NAMformer, a deep tabular model that preserves marginal feature effects. It integrates target-aware embeddings and uncontextualized feature embeddings with a transformer backbone and shallow per-feature nets, enabling identifiable marginal effects via a dropout-based additivity constraint. The approach achieves predictive performance on par with black-box models while providing interpretable marginal effects, demonstrated through simulations and real-data experiments across regression and classification tasks. This work advances intelligible deep learning for tabular data and offers theoretical guarantees for identifiability of marginal effects with practical applicability to high-risk domains.
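The additive structure described above can be illustrated with a small toy sketch. This is not the NAMpy implementation: the function `additive_prediction`, its parameters, and the plain Python callables standing in for the shallow per-feature nets and the transformer output are all hypothetical, and the sketch assumes that whole components are dropped independently during training, which this summary does not spell out.

```python
import random


def additive_prediction(x, shape_fns, joint_fn, intercept=0.0,
                        drop_prob=0.0, rng=None):
    """Toy additive prediction with optional component dropout.

    x          -- list of feature values
    shape_fns  -- one callable per feature (stand-ins for per-feature nets)
    joint_fn   -- callable on the full feature list (stand-in for the
                  contextualized transformer output)
    drop_prob  -- probability of dropping each component (training only;
                  assumed scheme, not taken from the paper)
    """
    rng = rng or random.Random()
    # Marginal contribution of each feature, computed from that feature alone.
    marginals = [f(v) for f, v in zip(shape_fns, x)]
    components = marginals + [joint_fn(x)]
    # Keep each component with probability 1 - drop_prob.
    kept = [c for c in components if rng.random() >= drop_prob]
    return intercept + sum(kept), marginals


# Hypothetical shape functions and interaction term, for illustration only.
shape_fns = [lambda v: 2.0 * v, lambda v: v ** 2]
joint_fn = lambda x: x[0] * x[1]

pred, marginals = additive_prediction([1.0, 3.0], shape_fns, joint_fn,
                                      intercept=0.5)
```

With `drop_prob=0.0` (inference), the prediction is the intercept plus all components, and `marginals` holds the per-feature contributions that can be plotted as marginal effects. Because each per-feature term sees only its own feature, dropping the joint term at random during training pushes the per-feature terms to explain as much of the target as they can on their own.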

Abstract

In recent years, deep neural networks have showcased their predictive power across a variety of tasks. Beyond natural language processing, the transformer architecture has proven efficient in addressing tabular data problems and challenges the previously dominant gradient-boosted decision trees in these areas. However, this predictive power comes at the cost of intelligibility: Marginal feature effects are almost completely lost in the black-box nature of deep tabular transformer networks. Alternative architectures that use the additivity constraints of classical statistical regression models can maintain intelligible marginal feature effects, but often fall short in predictive power compared to their more complex counterparts. To bridge the gap between intelligibility and performance, we propose an adaptation of tabular transformer networks designed to identify marginal feature effects. We provide theoretical justifications that marginal feature effects can be accurately identified, and our ablation study demonstrates that the proposed model efficiently detects these effects, even amidst complex feature interactions. To demonstrate the model's predictive capabilities, we compare it to several interpretable as well as black-box models and find that it can match black-box performances while maintaining intelligibility. The source code is available at https://github.com/OpenTabular/NAMpy.

Paper Structure

This paper contains 43 sections, 19 equations, 7 figures, 5 tables.

Figures (7)

  • Figure 1: Feature Encoding. The numerical features are independently encoded $(h(x_j, y))$ and afterwards passed through an embedding layer. The categorical features are tokenized and also passed through an embedding layer.
  • Figure 2: Training procedure of the proposed model structure. The architecture is conceptually very similar to FT-Transformer but allows marginal feature effects to be identified. Note that feature dropout is only applied during training.
  • Figure 3: Average $R^2$ values over all 9 features. The decision trees are fit using either the uncontextualized or the contextualized embeddings as training data and the true features as target variables.
  • Figure 4: Critical difference diagram for all inherently interpretable models on the benchmark datasets reported in Table \ref{tab:interpretable}. The average ranks across tasks are shown in brackets next to each model. Horizontal lines indicate groups of models with no statistically significant differences in performance. NAMformer achieves the best average rank and differs significantly at the 5% level from the second-best performing model, EB$^2$M. The critical differences are computed using the Conover-Friedman test \cite{pereira2015overview}, as both average ranks and performance metrics across all datasets are available.
  • Figure 5: Marginal feature predictions for a simple simulated example with 4 variables and the described data generating process. Over 25 runs, and with only 25 bins, NAMformer accurately identifies the marginal effects.
  • ...and 2 more figures
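The average ranks that a critical difference diagram such as Figure 4 reports can be computed as in the following minimal sketch. The helper `average_ranks`, the model names, and the scores are hypothetical placeholders, not results from the paper, and the Conover-Friedman post-hoc test itself is not implemented here.

```python
def average_ranks(scores, higher_is_better=True):
    """Average rank of each model across datasets (rank 1 = best).

    scores -- dict mapping model name to a list of per-dataset scores,
              with all lists in the same dataset order.
    Tied models on a dataset share the mean of the tied positions.
    """
    models = list(scores)
    n_datasets = len(next(iter(scores.values())))
    totals = {m: 0.0 for m in models}
    for d in range(n_datasets):
        ordered = sorted(models, key=lambda m: scores[m][d],
                         reverse=higher_is_better)
        i = 0
        while i < len(ordered):
            j = i
            # Extend j over the run of models tied with ordered[i].
            while (j + 1 < len(ordered)
                   and scores[ordered[j + 1]][d] == scores[ordered[i]][d]):
                j += 1
            tied_rank = (i + j) / 2 + 1  # mean of 1-based positions i+1..j+1
            for m in ordered[i:j + 1]:
                totals[m] += tied_rank
            i = j + 1
    return {m: totals[m] / n_datasets for m in models}


# Made-up scores on three datasets, for illustration only.
toy_scores = {
    "model_a": [0.9, 0.8, 0.7],
    "model_b": [0.8, 0.8, 0.6],
    "model_c": [0.7, 0.5, 0.9],
}
ranks = average_ranks(toy_scores)
```

For error metrics where lower is better, pass `higher_is_better=False`. A post-hoc test is then run on these per-dataset ranks to decide which models can be joined by a horizontal line in the diagram.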