InterpreTabNet: Distilling Predictive Signals from Tabular Data by Salient Feature Interpretation

Jacob Si, Wendy Yusi Cheng, Michael Cooper, Rahul G. Krishnan

TL;DR

InterpreTabNet tackles the interpretability gap in TabNet on tabular data by modeling per-step feature masks as latent variables drawn from a Gumbel-Softmax distribution and enforcing sparsity and diversity through a KL-divergence-based regularizer within a conditional variational autoencoder framework. This latent-mask formulation yields sparse, decision-step-specific feature selections that are easier to interpret, while preserving competitive predictive accuracy. The authors further augment intrinsic interpretability with post hoc, GPT-4-generated textual explanations of the learned feature interdependencies, evaluated through both human and language-model-based assessments. Collectively, InterpreTabNet advances interpretable deep learning for tabular domains and offers a practical toolkit for producing faithful, comprehensible insights in high-stakes settings.
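To make the latent-mask idea concrete, here is a minimal sketch of drawing one Gumbel-Softmax sample over D features for a single instance at a single decision step. It assumes the step's mask weights act as unnormalized logits, and the helper name `gumbel_softmax_sample` is hypothetical; the paper's actual parameterization and temperature schedule may differ.

```python
import math
import random


def gumbel_softmax_sample(logits, tau=0.5, rng=random):
    """Draw one relaxed (soft) and one hard sample over D features.

    logits: D unnormalized log-weights, one per feature (assumed here to play
            the role of the per-step mask parameters beta_0 .. beta_{D-1}).
    tau:    temperature; smaller values push the soft sample toward one-hot.
    """
    # Gumbel(0, 1) noise: g = -log(-log(u)) with u ~ Uniform(0, 1);
    # max() guards against u == 0, which random.random() can return.
    gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    perturbed = [(l + g) / tau for l, g in zip(logits, gumbels)]

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    m = max(perturbed)
    exps = [math.exp(p - m) for p in perturbed]
    total = sum(exps)
    soft = [e / total for e in exps]

    # Hard one-hot: argmax of the same perturbed logits (Gumbel-max trick),
    # i.e. the single feature selected for this instance at this step.
    k = max(range(len(soft)), key=soft.__getitem__)
    hard = [1.0 if i == k else 0.0 for i in range(len(soft))]
    return soft, hard


if __name__ == "__main__":
    rng = random.Random(0)
    soft, hard = gumbel_softmax_sample([2.0, 0.5, 0.1, -1.0, 0.0], tau=0.5, rng=rng)
    print([round(p, 3) for p in soft], hard)
```

The soft vector is the continuous relaxation that would carry gradients in an autodiff framework. The hard vector is the argmax of the same perturbed logits, so over many draws feature i is selected with probability softmax(logits)_i. This matches the "single highlighted feature" per instance and step described for Figure 1.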

Abstract

Tabular data are ubiquitous across industries. Neural networks for tabular data, such as TabNet, have been proposed to make predictions while leveraging an attention mechanism for interpretability. However, the inferred attention masks are often dense, which makes it difficult to derive rationales about the predictive signal. To remedy this, we propose InterpreTabNet, a variant of TabNet that models the attention mechanism as a latent variable sampled from a Gumbel-Softmax distribution. This lets us regularize the model, via a KL divergence regularizer, to learn distinct concepts in the attention masks. The regularizer promotes sparsity and thereby prevents overlapping feature selection, which improves the model's efficacy and makes it easier to determine which features are important for predicting the outcome. To assist in interpreting feature interdependencies from our model, we employ a large language model (GPT-4) and use prompt engineering to map the learned feature mask onto natural-language text describing the learned signal. Through comprehensive experiments on real-world datasets, we demonstrate that InterpreTabNet outperforms previous methods for interpreting tabular data while attaining competitive accuracy.
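The abstract does not state the exact form of the KL regularizer. As a hedged illustration of why a KL term can reward sparse masks, the sketch below measures a mask's KL divergence from the uniform distribution over D features, using the identity KL(m || Uniform_D) = log D - H(m). Choosing the uniform distribution as the reference, and the names `kl_divergence`, `entropy`, and `sparsity_score`, are assumptions made for illustration; this is not the paper's stated objective.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for discrete distributions over the same D features.

    Terms with p_i == 0 contribute 0 by convention; eps guards against q_i == 0.
    """
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def sparsity_score(mask):
    """Hypothetical sparsity measure: KL(mask || uniform over D features).

    Equals log D - H(mask): 0 for a fully dense (uniform) mask and log D for a
    one-hot mask that selects a single feature.
    """
    d = len(mask)
    return kl_divergence(mask, [1.0 / d] * d)


# Hypothetical masks over D = 5 features for one instance at one step.
dense_mask = [0.22, 0.20, 0.20, 0.19, 0.19]
sparse_mask = [0.90, 0.10, 0.0, 0.0, 0.0]

if __name__ == "__main__":
    for name, mask in [("dense", dense_mask), ("sparse", sparse_mask)]:
        print(f"{name}: KL to uniform = {sparsity_score(mask):.4f}")
```

Raising this score (equivalently, lowering the mask entropy) concentrates a step's mask on fewer features. In any real training objective such a term would be traded off against predictive loss, and the weighting used by InterpreTabNet is not given here.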

Paper Structure (35 sections, 6 equations, 44 figures, 11 tables, 1 algorithm)

Figures (44)

  • Figure 1: The InterpreTabNet architecture presents a variational formulation of the TabNet encoder. In our formulation, the weights of the attention masks produced by the TabNet encoder at each step $k$ are treated as the parameters, $\beta_0^{(i)}, \ldots, \beta_{D-1}^{(i)}$, of a Gumbel-Softmax distribution, $\Lambda_k$, unique to each instance (shown by the red dotted rectangle). This distribution is then sampled to produce a single feature that is highlighted for each sample at each step (purple dot-dashed rectangle). This figure shows $k=2$ steps of the encoder architecture, over $D=5$ features, for $N=3$ samples.
  • Figure 3: Graphical model of InterpreTabNet with $D$ i.i.d. samples. Solid lines denote the generative model $p_{\theta}(Y|z, X)p_{\theta}(z|X)$; dashed lines denote the variational approximation $q_{\phi}(z|X, Y)$ to the intractable posterior $p_{\theta}(z|X, Y)$. The variational parameters $\phi$ are learned jointly with the generative model parameters $\theta$.
  • Figure 15: Normalized Training Loss of InterpreTabNet vs. TabNet for the Adult Income Dataset
  • Figure 16: Feature Mask Definition Check
  • Figure : (a) InterpreTabNet Feature Mask ($r_M^* = 9$)
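The graphical model in Figure 3 pairs a generative model $p_{\theta}(Y|z, X)p_{\theta}(z|X)$ with a variational posterior $q_{\phi}(z|X, Y)$. Assuming the standard conditional-VAE derivation (the paper's own equations are not reproduced on this page), Jensen's inequality gives the evidence lower bound that $\theta$ and $\phi$ would be trained to maximize jointly. The integral becomes a sum if $z$ is treated as discrete.

```latex
\begin{aligned}
\log p_\theta(Y \mid X)
  &= \log \int p_\theta(Y \mid z, X)\, p_\theta(z \mid X)\, dz \\
  &= \log \mathbb{E}_{q_\phi(z \mid X, Y)}\!\left[
       \frac{p_\theta(Y \mid z, X)\, p_\theta(z \mid X)}{q_\phi(z \mid X, Y)} \right] \\
  &\ge \mathbb{E}_{q_\phi(z \mid X, Y)}\!\left[\log p_\theta(Y \mid z, X)\right]
       - \mathrm{KL}\!\left(q_\phi(z \mid X, Y) \,\Vert\, p_\theta(z \mid X)\right)
\end{aligned}
```

Here $z$ would correspond to the per-step feature masks. The TL;DR places the KL-based regularizer within this conditional VAE framework, but how exactly it relates to the KL term above, and whether the full objective adds further terms, is not stated on this page.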