Learning a Decision Tree Algorithm with Transformers
Yufan Zhuang, Liyuan Liu, Chandan Singh, Jingbo Shang, Jianfeng Gao
TL;DR
MetaTree presents a transformer-based approach to directly generate decision trees by meta-learning from both greedy CART and globally optimized GOSDT trees. The model employs a tabular Transformer with alternating row/column self-attention and learnable positional bias, trained via a two-phase curriculum and Gaussian-smoothed supervision to imitate optimal splits while preserving generalization. Empirical results show strong generalization to 91 unseen datasets and the ability to produce trees deeper than those seen during training, with lower empirical variance than the baselines. The work highlights the capacity of deep models to learn algorithmic strategies and to switch adaptively between greedy and global planning, offering a path toward differentiable model construction and potential LLM-guided decision processes.
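The TL;DR mentions Gaussian-smoothed supervision but does not give its form. The following is a hypothetical sketch of one way such a soft target could look, assuming that a split is encoded as the choice of a single cell (row i, column j) of the input matrix, meaning "split on feature j at threshold X[i, j]", and that rows are sorted by feature j so that neighbouring rows correspond to neighbouring thresholds. The function names, `sigma`, and the sorting assumption are illustrative choices, not details taken from the paper.

```python
import numpy as np

def gaussian_smoothed_split_target(n_rows, n_cols, split_row, split_col, sigma=1.0):
    """Hypothetical soft target over (row, column) cells.

    Assumes rows are sorted by feature `split_col`. All probability mass
    stays in the chosen column; inside it, a Gaussian centred on `split_row`
    spreads mass to nearby (similar) thresholds instead of a hard one-hot.
    """
    rows = np.arange(n_rows)
    weights = np.exp(-0.5 * ((rows - split_row) / sigma) ** 2)
    target = np.zeros((n_rows, n_cols))
    target[:, split_col] = weights / weights.sum()  # normalise to sum to 1
    return target

def soft_cross_entropy(logits, target):
    """Cross-entropy between a softmax over all cells and a soft target."""
    z = logits - logits.max()               # stabilise the exponentials
    log_probs = z - np.log(np.exp(z).sum())  # log-softmax over every cell
    return float(-(target * log_probs).sum())

target = gaussian_smoothed_split_target(n_rows=8, n_cols=3, split_row=4, split_col=1)
loss = soft_cross_entropy(np.zeros((8, 3)), target)  # uniform logits give log(24)
```

Compared with a hard one-hot target, a smoothed target penalises a prediction that picks an adjacent threshold less than one that picks an unrelated cell, which is one plausible motivation for smoothing split supervision.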
Abstract
Decision trees are renowned for their ability to achieve high predictive performance while remaining interpretable, especially on tabular data. Traditionally, they are constructed by recursive algorithms that partition the data at every node of the tree. However, identifying a good partition is challenging, as decision trees optimized for local segments may not yield global generalization. To address this, we introduce MetaTree, a transformer-based model trained via meta-learning to directly produce strong decision trees. Specifically, we fit both greedy decision trees and globally optimized decision trees on a large number of datasets, and train MetaTree to produce only the trees that achieve strong generalization performance. This training enables MetaTree to emulate these algorithms and intelligently adapt its strategy according to the context, thereby achieving superior generalization performance.
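For readers less familiar with the traditional recursive construction mentioned above, here is a minimal sketch of a greedy, CART-style builder for classification, assuming Gini impurity, axis-aligned threshold splits, and non-negative integer class labels. It is a textbook illustration of the greedy baseline, not MetaTree and not the exact CART implementation used in the paper. Each node picks the split that looks best locally, which is precisely the behaviour the abstract notes may not yield the best global tree.

```python
import numpy as np

def gini(y):
    """Gini impurity of a label vector."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Greedily pick the (feature, threshold) with lowest weighted impurity."""
    n, m = X.shape
    best = None  # (weighted_impurity, feature, threshold)
    for j in range(m):
        for t in np.unique(X[:, j])[:-1]:  # skip the max so both sides are non-empty
            left = X[:, j] <= t
            score = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / n
            if best is None or score < best[0]:
                best = (score, j, t)
    return best

def build_tree(X, y, depth=0, max_depth=2):
    """Recursively partition the data, choosing each split locally."""
    majority = int(np.bincount(y).argmax())
    if depth == max_depth or len(np.unique(y)) == 1:
        return {"leaf": majority}
    split = best_split(X, y)
    if split is None:  # every feature is constant at this node
        return {"leaf": majority}
    _, j, t = split
    left = X[:, j] <= t
    return {"feature": int(j), "threshold": float(t),
            "left": build_tree(X[left], y[left], depth + 1, max_depth),
            "right": build_tree(X[~left], y[~left], depth + 1, max_depth)}

def predict(node, x):
    """Follow splits from the root to a leaf."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
tree = build_tree(X, y)
```

On this toy data the root splits feature 0 at threshold 1.0 and both children are pure leaves. Globally optimized methods such as GOSDT instead search over whole trees, trading extra computation for splits that a purely local criterion can miss.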
