Learning Decision Trees as Amortized Structure Inference

Mohammed Mahfoud, Ghait Boukachab, Michał Koziarski, Alex Hernandez-Garcia, Stefan Bauer, Yoshua Bengio, Nikolay Malkin

TL;DR

This paper addresses the challenge of learning predictive models for tabular data by reframing decision tree construction as a sequential structure-inference problem. It introduces DT-GFN, which uses a Generative Flow Network (GFlowNet) to amortize sampling from the Bayesian posterior over decision trees, so that diverse, high-posterior trees can be generated and combined into an interpretable ensemble. The method has three components. The first is a principled prior over tree complexity. The second is a training objective based on trajectory balance (TB). The third is an architecture built on a Markov decision process (MDP) formulation of tree construction, which yields trees whose predictions can be combined in Bayesian ensembles whose performance scales with ensemble size. Empirically, DT-GFN outperforms state-of-the-art single-tree methods and some deep-learning baselines on standard tabular benchmarks. It is also robust to distribution shifts, performs well on anomaly detection, and improves consistently as ensemble size grows, while keeping tree descriptions short and interpretable.
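The TL;DR says only that the training objective is TB-based. For reference, the standard trajectory balance loss from the GFlowNet literature is shown below, and the exact variant used by DT-GFN may differ. A complete construction trajectory is written as $\tau = (s_0, s_1, \dots, s_n)$, where $s_0 = T_0$ and the final transition is the stop action that yields the tree $T$. $P_F$ is the forward policy, $P_B$ is the backward policy, and $\log Z_\theta$ is a learned log-partition estimate.

```latex
\mathcal{L}_{\mathrm{TB}}(\tau;\theta)
  = \left(
      \log Z_\theta
      + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t;\theta)
      - \log R(T)
      - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1};\theta)
    \right)^2
```

If this loss is driven to zero on all trajectories, the forward policy samples complete trees with probability $R(T)/Z$. Sampling from the Bayesian posterior $P(T \mid D)$ therefore corresponds to choosing $R(T) \propto P(D \mid T)\,P(T)$, the likelihood times the tree prior. In that case $Z$ plays the role of the marginal likelihood $P(D)$, up to the proportionality constant.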

Abstract

Building predictive models for tabular data presents fundamental challenges, notably in scaling consistently (i.e., more resources translating to better performance) and in generalizing systematically beyond the training data distribution. Designing decision tree models is especially challenging given the intractably large search space. Most existing methods rely on greedy heuristics, while the inductive biases of deep learning expect a temporal or spatial structure not naturally present in tabular data. We propose a hybrid amortized structure inference approach to learn predictive decision tree ensembles from data, formulating decision tree construction as a sequential planning problem. We train a deep reinforcement learning (GFlowNet) policy to solve this problem, yielding a generative model that samples decision trees from the Bayesian posterior. We show that our approach, DT-GFN, outperforms state-of-the-art decision tree and deep learning methods on standard classification benchmarks derived from real-world data, on robustness to distribution shifts, and on anomaly detection. It does so while yielding interpretable models with shorter description lengths. Samples from the trained DT-GFN model can be ensembled to construct a random forest. We further show that the performance of DT-GFN scales consistently with ensemble size, yielding ensembles of predictors that continue to generalize systematically.
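Ensembling posterior samples amounts to Bayesian model averaging. If trees $T_1, \dots, T_N$ are (approximately) samples from $P(T \mid D)$, the uniform average of their predictive distributions is a Monte Carlo estimate of the posterior predictive $p(y \mid x, D)$. The sketch below is a hypothetical illustration, not DT-GFN's code. It assumes each sampled tree is exposed as a function that returns class probabilities. The names `posterior_predictive` and `stump`, and the toy numbers, are made up for the example.

```python
import numpy as np
from typing import Callable, List, Sequence

# Hypothetical interface: a sampled tree maps an (n, d) array to (n, k) class probabilities.
Tree = Callable[[np.ndarray], np.ndarray]


def posterior_predictive(trees: List[Tree], X: np.ndarray) -> np.ndarray:
    """Uniform average of per-tree class probabilities.

    If the trees are samples from P(T | D), this is a Monte Carlo estimate
    of the posterior predictive p(y | x, D) (Bayesian model averaging).
    """
    probs = np.stack([tree(X) for tree in trees])  # shape (n_trees, n, k)
    return probs.mean(axis=0)


def stump(feature: int, threshold: float,
          p_left: Sequence[float], p_right: Sequence[float]) -> Tree:
    """Toy one-split tree: rows with X[:, feature] > threshold go right (rule True)."""
    left = np.asarray(p_left, dtype=float)
    right = np.asarray(p_right, dtype=float)

    def predict(X: np.ndarray) -> np.ndarray:
        go_right = X[:, feature] > threshold
        return np.where(go_right[:, None], right, left)  # broadcasts to (n, k)

    return predict


# Usage: two hand-made stumps stand in for trees sampled from a trained model.
X = np.array([[0.2, 0.9],
              [0.8, 0.1]])
trees = [stump(0, 0.5, [0.9, 0.1], [0.1, 0.9]),
         stump(1, 0.5, [0.8, 0.2], [0.3, 0.7])]
P = posterior_predictive(trees, X)
predicted_labels = P.argmax(axis=1)  # class with the highest averaged probability
```

Each tree's prediction is computed independently here, so the cost of this averaging grows linearly with the number of trees, and the trees can be evaluated in parallel.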

Paper Structure

This paper contains 57 sections, 22 equations, 5 figures, 12 tables, and 1 algorithm.

Figures (5)

  • Figure 1: Learning a decision tree as decision-making in a Markov decision process (MDP) $\mathcal{M}$. At each step of construction, the data is split by a decision threshold on one of the features (right for True [T], left for False [F]). We start from an empty source state $T_0$ with no decision rules. We then move through $\mathcal{M}$ by taking an action $a$ that corresponds to choosing a decision rule $(\ell, f, t)$, i.e., splitting the data at leaf $\ell$ on feature $f$ with threshold $t$. At each state of $\mathcal{M}$, $a$ can be either a valid action, resulting in a valid split, or an invalid one, resulting in an invalid or redundant split. At each state, we can also choose to stop sampling, in which case the resulting tree is a terminating state $\boldsymbol{\perp}$. The reward function $\mathcal{R}$ can be computed at any valid state.
  • Figure 2: In-distribution vs. out-of-distribution performance under distribution shift for tree-based methods, with ablations over ensemble sizes $[100, 500, 1000]$. The distribution shifts are caused by interventions on (a) BMI features and (b) Age features, and symbol sizes indicate ensemble size.
  • Figure 3: Varying the number of features in a hidden XOR task, where the label is the XOR of two features. Noise features are either binary (left) or real-valued (right). All datasets contain 1000 samples.
  • Figure 4: DT-GFN scaling with ensemble size ($[10, 100, 200, 500, 1000]$) and with allocated time/compute budget ([0 seconds, 10 seconds, 20 seconds]). The experiment is performed on the Iris dataset, and results are averaged over data split seeds $[1,2,3,4,5]$. Performance scales consistently with both ensemble size and training time/compute. Inference costs are negligible, as shown in the table of inference costs across ensemble sizes.
  • Figure 5: For tree-based methods, generalization accuracy increases systematically with ensemble size, both in-distribution and out-of-distribution.
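The construction process in Figure 1 can be made concrete with a small state machine. The sketch below is a hypothetical illustration, not DT-GFN's implementation, and it makes four assumptions:

- A state is a list of leaves, where each leaf holds the indices of the training rows that reach it.
- An action $(\ell, f, t)$ replaces leaf $\ell$ with its left (False) and right (True) children. This leaf-indexing convention is invented here.
- A split counts as invalid or redundant when one child would be empty.
- `log_reward` (a Laplace-smoothed leaf log-likelihood minus a per-leaf penalty standing in for a complexity prior) is a made-up stand-in for the paper's reward.

The usage builds a hidden-XOR dataset in the spirit of Figure 3, with binary noise features, and recovers the XOR rule with three valid splits.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

import numpy as np


@dataclass
class TreeState:
    """Partial tree: leaves[i] holds the indices of training rows reaching leaf i."""
    leaves: List[np.ndarray]
    rules: List[Tuple[int, int, float]] = field(default_factory=list)  # (leaf, feature, threshold) history
    done: bool = False  # True after the stop action (terminating state)


def initial_state(n_rows: int) -> TreeState:
    # Source state T_0: one leaf containing every row, no decision rules.
    return TreeState(leaves=[np.arange(n_rows)])


def is_valid(state: TreeState, X: np.ndarray, leaf: int, feature: int, threshold: float) -> bool:
    # Assumption: a split is invalid/redundant if it leaves either child empty.
    if state.done or not 0 <= leaf < len(state.leaves):
        return False
    go_right = X[state.leaves[leaf], feature] > threshold  # right = rule is True
    return bool(go_right.any() and (~go_right).any())


def split(state: TreeState, X: np.ndarray, leaf: int, feature: int, threshold: float) -> TreeState:
    """Apply action (leaf, feature, threshold); the two children replace the split leaf in place."""
    if not is_valid(state, X, leaf, feature, threshold):
        raise ValueError("invalid or redundant split")
    rows = state.leaves[leaf]
    go_right = X[rows, feature] > threshold
    children = [rows[~go_right], rows[go_right]]  # left (False), then right (True)
    leaves = state.leaves[:leaf] + children + state.leaves[leaf + 1:]
    return TreeState(leaves, state.rules + [(leaf, feature, threshold)])


def stop(state: TreeState) -> TreeState:
    # The stop action moves to a terminating state; no further splits are allowed.
    return TreeState(list(state.leaves), list(state.rules), done=True)


def log_reward(state: TreeState, y: np.ndarray, n_classes: int, penalty: float = 1.0) -> float:
    # Hypothetical log-reward: smoothed leaf log-likelihood minus a per-leaf complexity penalty.
    total = 0.0
    for rows in state.leaves:
        counts = np.bincount(y[rows], minlength=n_classes)
        probs = (counts + 1.0) / (counts.sum() + n_classes)
        total += float(counts @ np.log(probs))
    return total - penalty * len(state.leaves)


# Usage: hidden XOR of features 0 and 1, plus two binary noise features (1000 rows).
rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(1000, 4)).astype(float)
y = (X[:, 0] != X[:, 1]).astype(int)

root = initial_state(len(X))
s = split(root, X, leaf=0, feature=0, threshold=0.5)  # leaves: [x0=0, x0=1]
s = split(s, X, leaf=0, feature=1, threshold=0.5)     # split the x0=0 leaf
s = split(s, X, leaf=2, feature=1, threshold=0.5)     # the x0=1 leaf is now at index 2
redundant = is_valid(s, X, leaf=0, feature=0, threshold=0.5)  # leaf 0 already has x0=0 only
final = stop(s)
```

Under this toy reward, the terminated XOR tree scores far higher than the single-leaf source state. Re-splitting a leaf on a feature it has already been split on is rejected as redundant, mirroring the invalid actions in Figure 1.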