Deep Trees for (Un)structured Data: Tractability, Performance, and Interpretability

Dimitris Bertsimas, Lisa Everest, Jiayi Gu, Matthew Peroni, Vasiliki Stoumpou

TL;DR

This work develops a tractable approach to growing GSTs, given by the DeepTree algorithm, which, in addition to new regularization terms, produces high-quality models with far fewer nodes and greater interpretability than traditional soft trees.

Abstract

Decision Trees have remained a popular machine learning method for tabular datasets, mainly due to their interpretability. However, they lack the expressiveness needed to handle highly nonlinear or unstructured datasets. Motivated by recent advances in tree-based machine learning (ML) techniques and first-order optimization methods, we introduce Generalized Soft Trees (GSTs), which extend soft decision trees (STs) and are capable of processing images directly. We demonstrate their advantages with respect to tractability, performance, and interpretability. We develop a tractable approach to growing GSTs, given by the DeepTree algorithm, which, in addition to new regularization terms, produces high-quality models with far fewer nodes and greater interpretability than traditional soft trees. We test the performance of our GSTs on benchmark tabular and image datasets, including MIMIC-IV, MNIST, Fashion MNIST, CIFAR-10 and Celeb-A. We show that our approach outperforms other popular tree methods (CART, Random Forests, XGBoost) in almost all of the datasets, with Convolutional Trees having a significant edge in the hardest CIFAR-10 and Fashion MNIST datasets. Finally, we explore the interpretability of our GSTs and find that even the most complex GSTs are considerably more interpretable than deep neural networks. Overall, our approach of Generalized Soft Trees provides a tractable method that is high-performing on (un)structured datasets and preserves interpretability more than traditional deep learning methods.
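For readers unfamiliar with the soft decision trees that GSTs extend, the following is a minimal sketch of the forward pass of a generic soft tree with hyperplane (linear plus sigmoid) splits. It is not the authors' GST formulation or the DeepTree algorithm, which the abstract does not specify. The function names, the heap-style node indexing, and the randomly drawn parameters are all assumptions made for illustration.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def leaf_probabilities(x, weights, biases, depth):
    """Probability that input x reaches each of the 2**depth leaves.

    Inner nodes are stored level by level (heap order, hypothetical layout):
    the i-th node on a given level has index (2**level - 1) + i.
    """
    probs = [1.0]  # the root is reached with probability 1
    for level in range(depth):
        next_probs = []
        for i, p in enumerate(probs):
            node = (2 ** level - 1) + i
            z = sum(w * xi for w, xi in zip(weights[node], x)) + biases[node]
            go_right = sigmoid(z)  # soft routing instead of a hard threshold
            next_probs.append(p * (1.0 - go_right))
            next_probs.append(p * go_right)
        probs = next_probs
    return probs


def predict(x, weights, biases, leaf_dists, depth):
    """Mix the per-leaf class distributions by the leaf reach probabilities."""
    probs = leaf_probabilities(x, weights, biases, depth)
    n_classes = len(leaf_dists[0])
    return [sum(p * d[c] for p, d in zip(probs, leaf_dists))
            for c in range(n_classes)]


# Illustrative setup with random parameters (assumed values, not learned).
rng = random.Random(0)
depth, n_features, n_classes = 3, 4, 2
n_inner = 2 ** depth - 1  # a full depth-3 tree has 7 inner nodes
weights = [[rng.gauss(0, 1) for _ in range(n_features)] for _ in range(n_inner)]
biases = [rng.gauss(0, 1) for _ in range(n_inner)]

leaf_dists = []
for _ in range(2 ** depth):
    raw = [rng.random() + 1e-3 for _ in range(n_classes)]
    total = sum(raw)
    leaf_dists.append([r / total for r in raw])  # each leaf holds a distribution

x = [0.5, -1.0, 2.0, 0.1]
leaf_probs = leaf_probabilities(x, weights, biases, depth)
class_probs = predict(x, weights, biases, leaf_dists, depth)
```

Because every split is a smooth function of its parameters, a model of this shape can be trained with first-order methods. The same indexing also shows why a full depth-10 tree has 2**10 = 1024 leaves, the figure quoted in the Fashion MNIST caption.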

Paper Structure

This paper contains 18 sections, 11 equations, 7 figures, 16 tables, 3 algorithms.

Figures (7)

  • Figure 1: Results for Fashion MNIST GSTs. We observe that adding a small number of leaves to a full tree can improve out-of-sample performance. For example, for Hyperplane Trees, a tree with fewer than 200 leaves performs better than a depth-10 Full Tree with 1024 leaves.
  • Figure 2: Results for CIFAR-10 GSTs. We observe that grown trees offer a considerable performance enhancement. For example, in the case of Convolutional Trees, we can grow a depth-7 tree and achieve average accuracy comparable to a full depth-10 tree.
  • Figure 3: Sample node splits for the MIMIC-IV depth-4 full tree. We present up to three features with high absolute coefficients for each node. We observe that the tree trained with the sample penalty learns different splits at each inner node.
  • Figure 4: Anatomy of a Hyperplane Tree, trained for the MNIST dataset. The linear coefficients at each split are rearranged and plotted as 28x28 images. We see that 0s and 1s are always assigned to different child nodes, and that larger depth coefficients have more meaningful and complex structure.
  • Figure 5: Coefficients of the linear layer in different tree nodes. We observe how different nodes highlight different features of the images. These differences cause the top row of nodes to split female and male images to the right and left, respectively, while the bottom row splits female and male images to the left and right, respectively.
  • ...and 2 more figures