Distilling a Neural Network Into a Soft Decision Tree

Nicholas Frosst, Geoffrey Hinton

TL;DR

This paper tackles the challenge of interpreting deep neural networks by distilling their input–output behavior into a soft decision tree, whose hierarchical decisions yield explanations alongside predictions. The authors introduce the Hierarchical Mixture of Bigots, train it with gradient-based optimization and regularizers, and show that a tree distilled from a trained network generalizes better than a tree trained directly on the data, though it remains less accurate than the original network. They evaluate the approach on MNIST and other datasets such as Connect4, where a prediction can be explained by the path an input takes through the tree and by the learned filters at each inner node. The work points to a practical route to explainable AI: use a powerful neural net to train an explicitly interpretable model that retains much of the net's predictive power.

Abstract

Deep neural networks have proved to be a very effective way to perform classification tasks. They excel when the input data is high dimensional, the relationship between the input and the output is complicated, and the number of labeled training examples is large. But it is hard to explain why a learned network makes a particular classification decision on a particular test case. This is due to their reliance on distributed hierarchical representations. If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier. We describe a way of using a trained neural net to create a type of soft decision tree that generalizes better than one learned directly from the training data.
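
The idea of using a trained net to create the tree can be illustrated with a minimal sketch. It assumes the teacher network's softmax outputs are used as soft targets for the student, and compares the resulting cross-entropy with the one obtained from a hard one-hot label. The logits, the student probabilities, and the helper names `softmax` and `cross_entropy` are made up for illustration; they are not taken from the paper.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted, eps=1e-12):
    """Cross-entropy between a target distribution and a predicted one."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))

# Hypothetical teacher logits for one 3-class example (illustrative numbers).
teacher_logits = [4.0, 1.0, 0.5]

# Soft target: the teacher's full predictive distribution.
soft_target = softmax(teacher_logits)

# Hard target: the one-hot training label for the same example.
hard_target = [1.0, 0.0, 0.0]

# Hypothetical student (tree) output distribution for this example.
student_probs = [0.7, 0.2, 0.1]

# Training on soft targets penalizes the student on every class,
# not only on the labeled one.
loss_soft = cross_entropy(soft_target, student_probs)
loss_hard = cross_entropy(hard_target, student_probs)
```

With a hard label only the probability assigned to the labeled class enters the loss, whereas the soft target also carries the teacher's relative confidence in the other classes, which is the extra information distillation passes to the student.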

Paper Structure

This paper contains 7 sections, 5 equations, 3 figures.

Figures (3)

  • Figure 1: This diagram shows a soft binary decision tree with a single inner node and two leaf nodes.
  • Figure 2: This is a visualization of a soft decision tree of depth 4 trained on MNIST. The images at the inner nodes are the learned filters, and the images at the leaves are visualizations of the learned probability distribution over classes. The most likely final classification at each leaf and the likely classifications at each edge are annotated. Take, for example, the rightmost internal node: at that level in the tree the only potential classifications are 3 and 8, so the learned filter simply learns to distinguish between those two digits. The result is a filter that looks for the presence of two areas that would join the ends of the 3 to make an 8.
  • Figure 3: This is a visualization of the first 2 layers of a soft decision tree trained on the Connect4 dataset. Examining the learned filters shows that the game can be split into two distinct subtypes: games where the players have placed pieces on the edges of the board, and games where the players have placed pieces in the center of the board.
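
The smallest tree in the figures above, one inner node with two leaves, can be sketched in a few lines. This is a minimal illustration, assuming the inner node applies a logistic function to a learned filter `w` and bias `b` to get the probability of taking the right branch, and that each leaf holds a fixed softmax distribution over classes. The function name `soft_tree_depth1` and all parameter values are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def soft_tree_depth1(x, w, b, phi_left, phi_right):
    """Soft binary tree with one inner node and two leaves.

    Returns the probability of the right branch, the path-weighted
    mixture of the leaf distributions, and the distribution of the
    leaf with the larger path probability.
    """
    # Inner node: filter response decides how strongly to go right.
    z = sum(xi * wi for xi, wi in zip(x, w)) + b
    p_right = sigmoid(z)

    # Leaves: fixed distributions that do not depend on the input.
    q_left = softmax(phi_left)
    q_right = softmax(phi_right)

    # Option 1: average the leaves, weighted by path probability.
    mixture = [(1.0 - p_right) * l + p_right * r
               for l, r in zip(q_left, q_right)]

    # Option 2: follow only the most probable path.
    most_likely_leaf = q_right if p_right >= 0.5 else q_left
    return p_right, mixture, most_likely_leaf

# Illustrative 2-feature input and parameters (made-up values).
x = [1.0, 0.0]
w = [2.0, -1.0]
b = 0.0
p_right, mixture, leaf = soft_tree_depth1(x, w, b, [2.0, 0.0], [0.0, 2.0])
```

Following only the most probable path keeps the explanation short: the prediction is justified by the single filter at the inner node and the one leaf distribution that was reached.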