GradTree: Learning Axis-Aligned Decision Trees with Gradient Descent

Sascha Marton, Stefan Lüdtke, Christian Bartelt, Heiner Stuckenschmidt

TL;DR

GradTree introduces a gradient-based framework for learning hard axis-aligned decision trees by applying backpropagation to a dense, differentiable tree representation. It addresses non-differentiability at split decisions with a straight-through estimator and entmax-based relaxations, enabling joint optimization of all tree parameters. Empirical results show superior binary-classification performance and competitive multi-class results, coupled with smaller pruned trees and robustness to overfitting, while maintaining scalable training on high-dimensional data. The approach offers flexibility for custom loss functions and has potential to extend to axis-aligned tree ensembles learned end-to-end with gradient-based optimization.

Abstract

Decision Trees (DTs) are commonly used for many machine learning tasks due to their high degree of interpretability. However, learning a DT from data is a difficult optimization problem, as it is non-convex and non-differentiable. Therefore, common approaches learn DTs using a greedy growth algorithm that minimizes the impurity locally at each internal node. Unfortunately, this greedy procedure can lead to inaccurate trees. In this paper, we present a novel approach for learning hard, axis-aligned DTs with gradient descent. The proposed method uses backpropagation with a straight-through operator on a dense DT representation, to jointly optimize all tree parameters. Our approach outperforms existing methods on binary classification benchmarks and achieves competitive results for multi-class tasks. The method is available under: https://github.com/s-marton/GradTree
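To make the straight-through idea concrete, the following is a minimal sketch for a single axis-aligned split on one feature. It is not the authors' implementation. The forward pass uses a hard 0/1 decision (a rounded logistic), while the backward pass treats the rounding as the identity and uses the gradient of the logistic instead. The function names, the squared-error loss, the toy data, and the learning rate are all assumptions chosen for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def hard_split(x, threshold):
    """Forward pass: hard 0/1 routing decision, i.e. a rounded logistic."""
    return 1.0 if sigmoid(x - threshold) >= 0.5 else 0.0


def straight_through_grad(x, threshold):
    """Backward pass: rounding is treated as the identity, so the gradient of
    the hard decision w.r.t. the threshold is that of sigmoid(x - threshold)."""
    s = sigmoid(x - threshold)
    return -s * (1.0 - s)


def train_threshold(xs, ys, threshold=0.0, lr=0.5, steps=50):
    """Gradient descent on a squared-error loss over the hard split outputs."""
    for _ in range(steps):
        grad = 0.0
        for x, y in zip(xs, ys):
            error = hard_split(x, threshold) - y
            grad += 2.0 * error * straight_through_grad(x, threshold)
        threshold -= lr * grad
    return threshold


# Toy data (hypothetical): the label is 1 when the feature exceeds about 2.
XS = [0.5, 1.0, 1.5, 2.5, 3.0, 3.5]
YS = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]

learned = train_threshold(XS, YS)
print(learned, [hard_split(x, learned) for x in XS])
```

Without the surrogate gradient, the hard decision has zero gradient almost everywhere and the threshold would never move. With it, the threshold shifts until every sample is routed correctly, after which the error terms vanish and the updates stop.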

Paper Structure

This paper contains 36 sections, 6 equations, 2 figures, 16 tables, 2 algorithms.

Figures (2)

  • Figure 1: Greedy vs. Gradient-Based DT. Two DTs trained on the Echocardiogram dataset. The CART DT (left) makes only locally optimal splits, while GradTree (right) jointly optimizes all parameters, leading to significantly better performance.
  • Figure 2: Standard vs. Dense DT Representation. Comparison of a standard and the equivalent dense representation for an exemplary DT with depth $2$ and a dataset with $3$ variables and $2$ classes. Here, $\mathbb{S}_{\text{lh}}$ stands for $\mathbb{S}_\text{logistic\_hard}$ (Equation \ref{eq:split_sigmoid_round}).