TabNet: Attentive Interpretable Tabular Learning

Sercan O. Arik, Tomas Pfister

TL;DR

The paper tackles the challenge of high-performance, interpretable tabular learning by introducing TabNet, a canonical deep architecture that uses sequential attentive feature masks to select salient features at each decision step. It combines sparse, instance-wise feature selection with multi-step reasoning, achieving strong empirical results and enabling both local and global interpretability. A key contribution is the demonstration of self-supervised pre-training for tabular data, which yields substantial gains in downstream supervised tasks and faster convergence. Overall, TabNet provides a competitive, interpretable alternative to tree ensembles and introduces a new direction for unsupervised representation learning in tabular domains.

Abstract

We propose a novel high-performance and interpretable canonical deep tabular data learning architecture, TabNet. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more efficient learning as the learning capacity is used for the most salient features. We demonstrate that TabNet outperforms other neural network and decision tree variants on a wide range of non-performance-saturated tabular datasets and yields interpretable feature attributions plus insights into the global model behavior. Finally, for the first time to our knowledge, we demonstrate self-supervised learning for tabular data, significantly improving performance with unsupervised representation learning when unlabeled data is abundant.

Paper Structure

This paper contains 16 sections, 7 figures, 11 tables.

Figures (7)

  • Figure 1: TabNet's sparse feature selection exemplified for Adult Census Income prediction (UCI). Sparse feature selection enables interpretability and better learning as the capacity is used for the most salient features. TabNet employs multiple decision blocks that focus on processing a subset of input features for reasoning. Two decision blocks shown as examples process features that are related to professional occupation and investments, respectively, in order to predict the income level.
  • Figure 2: Self-supervised tabular learning. Real-world tabular datasets have interdependent feature columns, e.g., the education level can be guessed from the occupation, or the gender can be guessed from the relationship. Unsupervised representation learning by masked self-supervised learning results in an improved encoder model for the supervised learning task.
  • Figure 3: Illustration of DT-like classification using conventional DNN blocks (left) and the corresponding decision manifold (right). Relevant features are selected by using multiplicative sparse masks on inputs. The selected features are linearly transformed, and after a bias addition (to represent boundaries) ReLU performs region selection by zeroing the regions. Aggregation of multiple regions is based on addition. As $C_1$ and $C_2$ get larger, the decision boundary gets sharper.
  • Figure 4: (a) TabNet encoder, composed of a feature transformer, an attentive transformer and feature masking. A split block divides the processed representation to be used by the attentive transformer of the subsequent step as well as for the overall output. For each step, the feature selection mask provides interpretable information about the model's functionality, and the masks can be aggregated to obtain global feature importance attribution. (b) TabNet decoder, composed of a feature transformer block at each step. (c) A feature transformer block example: a 4-layer network is shown, where 2 layers are shared across all decision steps and 2 are decision step-dependent. Each layer is composed of a fully-connected (FC) layer, BN and GLU nonlinearity. (d) An attentive transformer block example: a single-layer mapping is modulated with prior scale information, which aggregates how much each feature has been used before the current decision step. Sparsemax is used for normalization of the coefficients, resulting in sparse selection of the salient features.
  • Figure 5: Feature importance masks $\mathbf{M[i]}$ (that indicate feature selection at the $i^{th}$ step) and the aggregate feature importance mask $\mathbf{M_{agg}}$ showing the global instance-wise feature selection, on Syn2 and Syn4 (L2X). Brighter colors show a higher value. E.g., for Syn2, only $X_3$-$X_6$ are used.
  • ...and 2 more figures
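
The attentive transformer described in Figure 4(d) can be illustrated with a small NumPy sketch. It is not the authors' implementation: the learned single-layer mapping is replaced by hand-written logits, the function names `sparsemax` and `attentive_masks` are hypothetical, and the prior update `prior * (gamma - mask)` with relaxation parameter `gamma` is an assumption about how "how much each feature has been used" is aggregated, since the caption does not spell it out.

```python
import numpy as np


def sparsemax(z):
    """Euclidean projection of each row of z onto the probability simplex.

    Unlike softmax, the result can contain exact zeros, which is what
    makes the resulting feature masks sparse.
    """
    z = np.asarray(z, dtype=float)
    z_sorted = -np.sort(-z, axis=-1)              # sort each row descending
    cumsum = np.cumsum(z_sorted, axis=-1)
    k = np.arange(1, z.shape[-1] + 1)
    support = 1.0 + k * z_sorted > cumsum         # True for entries kept in the support
    k_z = support.sum(axis=-1, keepdims=True)     # support size per row
    tau = (np.take_along_axis(cumsum, k_z - 1, axis=-1) - 1.0) / k_z
    return np.maximum(z - tau, 0.0)


def attentive_masks(logits_per_step, gamma=1.3):
    """Compute one sparse feature mask per decision step.

    logits_per_step: list of (batch, features) arrays standing in for the
    outputs of the learned mapping (an assumption made for illustration).
    gamma: assumed relaxation parameter; gamma = 1 discourages reusing a
    fully selected feature, larger values allow more reuse.
    """
    prior = np.ones_like(np.asarray(logits_per_step[0], dtype=float))
    masks = []
    for logits in logits_per_step:
        mask = sparsemax(prior * np.asarray(logits, dtype=float))
        masks.append(mask)
        prior = prior * (gamma - mask)            # down-weight features already used
    return masks


if __name__ == "__main__":
    logits = np.array([[3.0, 1.0, 0.5]])
    for step, m in enumerate(attentive_masks([logits, logits], gamma=1.0), start=1):
        print(f"step {step} mask: {m}")
```

With `gamma=1.0` and identical logits at both steps, the feature chosen in the first step receives zero weight in the second, so the second mask spreads over the remaining features. Each mask row sums to one, which is the normalization role the caption assigns to sparsemax.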