Learning a Decision Tree Algorithm with Transformers

Yufan Zhuang, Liyuan Liu, Chandan Singh, Jingbo Shang, Jianfeng Gao

TL;DR

MetaTree presents a transformer-based approach to directly generate decision trees by meta-learning from both greedy CART and globally optimized GOSDT trees. The model employs a tabular Transformer with alternating row/column self-attention and learnable positional bias, trained via a two-phase curriculum and Gaussian-smoothed supervision to imitate optimal splits while preserving generalization. Empirical results show strong generalization to 91 unseen datasets and the ability to produce deeper trees than seen during training, with lower empirical variance compared to baselines. The work highlights the capacity of deep models to learn algorithmic strategies and adaptively switch between greedy and global planning, offering a path toward differentiable model construction and potential LLM-guided decision processes.
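
The summary above mentions Gaussian-smoothed supervision without giving the formula. As a rough illustration only, the sketch below shows one plausible reading: instead of a hard one-hot target on the teacher's threshold, candidate thresholds in the chosen feature column receive probability mass that decays with their distance from it. The function names, the `sigma` parameter, and the exact form of the weighting are assumptions made for this example and are not taken from the paper.

```python
import math

def gaussian_smoothed_target(column, best_threshold, sigma):
    """Soft target over the candidate thresholds of one feature column.

    Every value in `column` is a candidate threshold. Candidates close to the
    teacher's threshold get more mass than distant ones (assumed form).
    """
    weights = [
        math.exp(-((x - best_threshold) ** 2) / (2.0 * sigma ** 2))
        for x in column
    ]
    total = sum(weights)
    return [w / total for w in weights]

def cross_entropy(target, predicted, eps=1e-12):
    """Cross-entropy between a soft target and a predicted distribution."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, predicted))

# Hypothetical feature column; the teacher tree split at 0.5.
column = [0.1, 0.4, 0.5, 0.9]
target = gaussian_smoothed_target(column, best_threshold=0.5, sigma=0.2)

# A prediction that matches the soft target scores better than a uniform one.
loss_matching = cross_entropy(target, target)
loss_uniform = cross_entropy(target, [0.25] * len(column))
```

With a target of this kind, a prediction that lands on a neighboring threshold (0.4 here) is penalized less than one that lands far away (0.9), which is the practical effect smoothing is usually meant to have.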

Abstract

Decision trees are renowned for their ability to achieve high predictive performance while remaining interpretable, especially on tabular data. Traditionally, they are constructed through recursive algorithms, where they partition the data at every node in a tree. However, identifying a good partition is challenging, as decision trees optimized for local segments may not yield global generalization. To address this, we introduce MetaTree, a transformer-based model trained via meta-learning to directly produce strong decision trees. Specifically, we fit both greedy decision trees and globally optimized decision trees on a large number of datasets, and train MetaTree to produce only the trees that achieve strong generalization performance. This training enables MetaTree to emulate these algorithms and intelligently adapt its strategy according to the context, thereby achieving superior generalization performance.
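
To make the recursive construction described above concrete, here is a minimal sketch of a tree builder whose split rule is pluggable. The default rule is a greedy Gini-gain search in the spirit of CART; a learned model would be passed in through the hypothetical `choose_split` argument instead. All names, the `<=` split convention, and the dictionary tree format are assumptions made for this example, not the authors' implementation.

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def greedy_split(rows, labels):
    """Return (feature, threshold, gain) with the largest Gini gain, or None."""
    parent, n, best = gini(labels), len(labels), None
    for j in range(len(rows[0])):
        for t in sorted({r[j] for r in rows}):
            left = [y for r, y in zip(rows, labels) if r[j] <= t]
            right = [y for r, y in zip(rows, labels) if r[j] > t]
            if not left or not right:
                continue  # a split must send data to both children
            gain = parent - len(left) / n * gini(left) - len(right) / n * gini(right)
            if best is None or gain > best[2]:
                best = (j, t, gain)
    return best

def build_tree(rows, labels, depth, choose_split=greedy_split):
    """Recursively partition the data; `choose_split` decides each node."""
    majority = Counter(labels).most_common(1)[0][0]
    if depth == 0 or len(set(labels)) == 1:
        return {"leaf": majority}
    split = choose_split(rows, labels)
    if split is None:
        return {"leaf": majority}
    j, t, _ = split
    left = [(r, y) for r, y in zip(rows, labels) if r[j] <= t]
    right = [(r, y) for r, y in zip(rows, labels) if r[j] > t]
    if not left or not right:
        return {"leaf": majority}
    return {
        "feature": j,
        "threshold": t,
        "left": build_tree([r for r, _ in left], [y for _, y in left],
                           depth - 1, choose_split),
        "right": build_tree([r for r, _ in right], [y for _, y in right],
                            depth - 1, choose_split),
    }

def predict(tree, row):
    """Follow the splits from the root down to a leaf label."""
    while "leaf" not in tree:
        go_left = row[tree["feature"]] <= tree["threshold"]
        tree = tree["left"] if go_left else tree["right"]
    return tree["leaf"]

# Four-point XOR table: the label is 1 when exactly one feature is 1.
xor_rows = [(0, 0), (0, 1), (1, 0), (1, 1)]
xor_labels = [0, 1, 1, 0]
root_split = greedy_split(xor_rows, xor_labels)
stump = build_tree(xor_rows, xor_labels, depth=1)
tree = build_tree(xor_rows, xor_labels, depth=2)
stump_correct = sum(predict(stump, r) == y for r, y in zip(xor_rows, xor_labels))
tree_correct = sum(predict(tree, r) == y for r, y in zip(xor_rows, xor_labels))
```

On this XOR table every candidate root split has zero Gini gain, so the greedy criterion gives no signal about which split to take at the root. This is a small instance of the point made above: a split that looks no better than any other locally can still be part of a tree that fits the data once a second level is added.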

Paper Structure (42 sections, 4 equations, 11 figures, 7 tables)

This paper contains 42 sections, 4 equations, 11 figures, 7 tables.

Figures (11)

  • Figure 1: MetaTree Methodology. The creation of a decision tree, depicted in (a), entails recursive MetaTree calls. MetaTree only assesses the current state for its decision-making. (b) shows MetaTree's architecture, in which the tabular input ($n$ data points of $m$ features) is embedded in a representation space and processed with row and column attention at each layer; the output is a one-hot mask indicating the splitting feature $j$ and threshold $\mathrm{X}_{i,j}$. (c) illustrates MetaTree's two-phase learning curriculum: in the first phase, the focus is exclusively on learning from the optimized GOSDT trees, to closely emulate the behavior of the GOSDT algorithm. In the second phase, training incorporates data from both the GOSDT and CART trees, and MetaTree learns to generate the trees that generalize better.
  • Figure 2: MetaTree demonstrates strong generalization on real-world datasets. MetaTree generalizes well to 91 held-out datasets for (A) depth-2 trees, (B) depth-3 trees, and (C) depth-4 trees, despite only being trained to produce depth-2 trees. MetaTree also generalizes well to the 13 Tree-of-prompts datasets for (D) depth-2 trees, (E) depth-3 trees, and (F) depth-4 trees, which require constructing a tree to steer a large language model (Morris et al., 2023). Each plot shows the average test accuracy for tree ensembles of size {1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, with error bars indicating the standard deviation.
  • Figure 3: We show that MetaTree learns to adapt its splitting strategy for better generalization. (a), (b) The figures demonstrate MetaTree's tendency to select the more effective generalization strategy: opting for the greedy algorithm CART when CART performs better and the optimal algorithm GOSDT when GOSDT is the better choice. (c) We also conduct regression analysis showing that MetaTree's algorithmic preference positively correlates with the better generalizing algorithm's performance.
  • Figure 4: Empirical bias-variance decomposition for MetaTree, GOSDT, and CART on 91 left-out datasets with 100 repetitions is shown in (a). MetaTree has significantly lower variance and slightly smaller bias as compared to GOSDT and CART. (b) We compare accuracy delta (y-axis) for each dataset (x-axis) when fitting a single tree. The delta is the change compared to the mean accuracy of CART and GOSDT.
  • Figure A1: Exploratory analysis of MetaTree. (a), (b) We examine MetaTree's performance in a controlled XOR setting with various noise levels and problem difficulties. We show an illustration of MetaTree solving Level 1 XOR and generalizing to Level 2 XOR, while greedy algorithms like CART are unable to solve these. (c), (d) We probe MetaTree's decision-making process across the Transformer layers and find that MetaTree very often generates the final split early on.
  • ...and 6 more figures
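
The Figure 1 caption describes the model output as a one-hot mask over the $n \times m$ table that picks a splitting feature $j$ and threshold $\mathrm{X}_{i,j}$. The sketch below shows how such a mask could be decoded into a split and applied to the rows. The helper names `decode_split` and `partition`, the toy table, and the "less than or equal goes left" convention are assumptions made for this example.

```python
def decode_split(mask, table):
    """Turn an n x m one-hot mask into (feature index j, threshold X[i][j])."""
    hits = [
        (i, j)
        for i, row in enumerate(mask)
        for j, value in enumerate(row)
        if value == 1
    ]
    if len(hits) != 1:
        raise ValueError("mask must contain exactly one 1")
    i, j = hits[0]
    return j, table[i][j]

def partition(table, feature, threshold):
    """Split row indices by the assumed rule: value <= threshold goes left."""
    left = [idx for idx, row in enumerate(table) if row[feature] <= threshold]
    right = [idx for idx, row in enumerate(table) if row[feature] > threshold]
    return left, right

# Hypothetical table with n = 3 rows and m = 2 features.
table = [
    [5.1, 0.2],
    [6.3, 1.8],
    [4.9, 0.3],
]
# The single 1 sits at row 2, column 1, so the split is "feature 1 <= 0.3".
mask = [
    [0, 0],
    [0, 0],
    [0, 1],
]
feature, threshold = decode_split(mask, table)
left_rows, right_rows = partition(table, feature, threshold)
```

Because the threshold is always one of the observed values $\mathrm{X}_{i,j}$, a single mask position is enough to name both the feature and the cut point, and each child subset can then be passed to the next recursive call.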