
TABCF: Counterfactual Explanations for Tabular Data Using a Transformer-Based VAE

Emmanouil Panagiotou, Manuel Heurich, Tim Landgraf, Eirini Ntoutsi

TL;DR

TABCF is a CF explanation method built on a transformer-based Variational Autoencoder (VAE) tailored for modeling tabular data. Transformers learn a continuous latent space, and a novel Gumbel-Softmax detokenizer enables precise categorical reconstruction while preserving end-to-end differentiability.

Abstract

In the field of Explainable AI (XAI), counterfactual (CF) explanations are a prominent method for interpreting a black-box model by suggesting changes to the input that would alter its prediction. In real-world applications, the input is predominantly tabular, composed of mixed data types with complex feature interdependencies. These data characteristics are difficult to model, and we empirically show that they lead to a bias toward specific feature types when generating CFs. To overcome this issue, we introduce TABCF, a CF explanation method that leverages a transformer-based Variational Autoencoder (VAE) tailored for modeling tabular data. Our approach uses transformers to learn a continuous latent space and a novel Gumbel-Softmax detokenizer that enables precise categorical reconstruction while preserving end-to-end differentiability. Extensive quantitative evaluation on five financial datasets demonstrates that TABCF does not exhibit bias toward specific feature types and outperforms existing methods in producing effective CFs that align with common CF desiderata.
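As a rough illustration of the Gumbel-Softmax idea behind the detokenizer, the sketch below draws a relaxed one-hot sample for a single categorical feature: Gumbel(0, 1) noise is added to the category logits and a temperature-scaled softmax is applied, so lower temperatures push the output toward a one-hot vector while, for fixed noise, the output remains a smooth function of the logits. This is a minimal standard-library sketch under stated assumptions, not TABCF's implementation: the helper names `gumbel_softmax` and `hard_one_hot`, the temperature values, and the use of a hard argmax in the forward pass are hypothetical choices, and how the paper's detokenizer computes its logits is not described here.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed one-hot sample: softmax((logits + g) / tau), g ~ Gumbel(0, 1)."""
    rng = rng or random.Random()
    noisy = []
    for logit in logits:
        # Inverse-transform sampling of Gumbel(0, 1): g = -log(-log(u)), u ~ U(0, 1).
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / tau)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


def hard_one_hot(probs):
    """Forward value of a straight-through estimator: one-hot at the argmax.

    Pure Python has no autograd, so only the forward value is computed here.
    """
    k = max(range(len(probs)), key=probs.__getitem__)
    return [1.0 if i == k else 0.0 for i in range(len(probs))]


# Usage: a hypothetical categorical feature with three categories.
logits = [2.0, 0.5, -1.0]
soft = gumbel_softmax(logits, tau=1.0, rng=random.Random(0))
sharp = gumbel_softmax(logits, tau=0.1, rng=random.Random(0))
print("soft sample: ", [round(p, 3) for p in soft])
print("sharp sample:", [round(p, 3) for p in sharp])
print("hard one-hot:", hard_one_hot(sharp))
```

In an autodiff framework, a straight-through variant would use the hard one-hot vector in the forward pass while backpropagating through the soft probabilities; the temperature controls how close the soft sample is to a valid category.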

Paper Structure

This paper contains 18 sections, 9 equations, 6 figures, and 3 tables.

Figures (6)

  • Figure 1: Overview of the counterfactual generation process. Bold arrows indicate data flow, and dashed arrows indicate backward gradient flow. We iteratively optimize the latent representation $z$ of the counterfactual $x'$ using three distinct loss terms.
  • Figure 2: Overview of the Variational Autoencoder training process. Blue indicates the process for numerical features and yellow the process for categorical features. The detokenizer uses the Gumbel-Softmax function to reconstruct categorical features, keeping the pipeline fully differentiable.
  • Figure 3: SHAP values for $X^{0}_{test}$ on the Adult dataset. Colored dots indicate numerical features and grey dots indicate categorical features. On the x-axis, values to the right indicate a positive impact on the target class and values to the left a negative impact.
  • Figure 4: The left plot visualizes a one-hot vector with seven categories ($c_i$ on the x-axis) throughout the optimization process. On the right, the values of the one-hot regularization loss (used by DiCE) are plotted (x-axis) over the number of optimization steps (y-axis).
  • Figure 5: Ablation study of the loss hyperparameters on the Adult dataset for the Validity, Proximity, and Sparsity metrics. The x-axes show increasing values of $\lambda_{prox\_latent}$ and the y-axes increasing values of $\lambda_{prox\_input}$. Stronger saturation indicates a better score.
  • ...and 1 more figure
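Figure 1 describes iteratively optimizing the latent code $z$ of the counterfactual with three loss terms, and Figure 5 varies two of their weights, $\lambda_{prox\_latent}$ and $\lambda_{prox\_input}$. The sketch below illustrates that kind of latent-space search on toy models and should be read as an assumption-laden illustration, not the paper's method: the linear `decode` stands in for the trained VAE decoder, the logistic `predict_proba` for the black-box classifier, and the concrete loss forms (binary cross-entropy for validity, squared L2 distances for the latent and input proximity terms) are assumed rather than taken from the paper. Central finite differences replace the automatic differentiation a real pipeline would use; all names, weights, and values are hypothetical.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


# Hypothetical toy stand-ins: a linear "decoder" from a 2-D latent space to a
# 3-D input space, and a logistic "black-box" classifier on that input space.
DEC_W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
DEC_B = [0.0, 0.0, 0.0]
CLF_W = [1.5, -1.0, 2.0]
CLF_B = -1.0


def decode(z):
    return [sum(w * zj for w, zj in zip(row, z)) + b for row, b in zip(DEC_W, DEC_B)]


def predict_proba(x):
    return sigmoid(sum(w * xi for w, xi in zip(CLF_W, x)) + CLF_B)


def cf_loss(z, z0, x0, target, lam_latent, lam_input):
    """Assumed loss: validity + weighted latent proximity + weighted input proximity."""
    x = decode(z)
    p = predict_proba(x)
    eps = 1e-12  # avoids log(0)
    # Validity: binary cross-entropy toward the desired class (assumed form).
    validity = -(target * math.log(p + eps) + (1 - target) * math.log(1 - p + eps))
    # Proximity: squared L2 distances in latent and input space (assumed form).
    prox_latent = sum((a - b) ** 2 for a, b in zip(z, z0))
    prox_input = sum((a - b) ** 2 for a, b in zip(x, x0))
    return validity + lam_latent * prox_latent + lam_input * prox_input


def numerical_grad(f, z, h=1e-5):
    """Central finite differences, standing in for autodiff."""
    grad = []
    for i in range(len(z)):
        zp, zm = list(z), list(z)
        zp[i] += h
        zm[i] -= h
        grad.append((f(zp) - f(zm)) / (2 * h))
    return grad


def optimize_cf(z0, target=1, lam_latent=0.1, lam_input=0.1, lr=0.1, steps=500):
    """Plain gradient descent on the latent code; returns (z_cf, decoded x_cf)."""
    x0 = decode(z0)

    def loss(z):
        return cf_loss(z, z0, x0, target, lam_latent, lam_input)

    z = list(z0)
    for _ in range(steps):
        g = numerical_grad(loss, z)
        z = [zi - lr * gi for zi, gi in zip(z, g)]
    return z, decode(z)


z0 = [-1.0, 1.0]
print("original p(y=1):      ", round(predict_proba(decode(z0)), 3))
z_cf, x_cf = optimize_cf(z0)
print("counterfactual p(y=1):", round(predict_proba(x_cf), 3))
```

With the default weights the decoded counterfactual crosses the 0.5 decision threshold. Raising both proximity weights (for example to 1.0) keeps $z$ closer to its starting point and, in this toy setup, stops the prediction from flipping at all: the same kind of validity/proximity trade-off that the Figure 5 ablation explores.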