Learning Tree-Based Models with Gradient Descent

Sascha Marton

Abstract

Tree-based models are widely recognized for their interpretability and have proven effective in a variety of application domains, particularly high-stakes ones. However, learning decision trees (DTs) poses a significant challenge due to their combinatorial complexity and discrete, non-differentiable nature. As a result, traditional methods such as CART, which rely on greedy search procedures, remain the most widely used approaches. These methods make locally optimal decisions at each node, constraining the search space and often leading to suboptimal tree structures. Additionally, their reliance on custom training procedures precludes seamless integration into modern machine learning (ML) pipelines. In this thesis, we propose a novel method for learning hard, axis-aligned DTs through gradient descent. Our approach uses backpropagation with a straight-through operator on a dense DT representation, enabling the joint optimization of all tree parameters and thereby addressing the two primary limitations of traditional DT algorithms. First, gradient-based training is not constrained by the sequential selection of locally optimal splits; instead, it jointly optimizes all tree parameters. Second, because it relies on gradient descent, our approach integrates seamlessly into existing ML frameworks, e.g., for multimodal and reinforcement learning tasks, which inherently rely on gradient descent. These advancements allow us to achieve state-of-the-art results across multiple domains, including interpretable DTs for small tabular datasets, advanced models for complex tabular data, multimodal learning, and interpretable reinforcement learning without information loss. By bridging the gap between DTs and gradient-based optimization, our method significantly enhances the performance and applicability of tree-based models across various ML domains.
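To make the straight-through idea concrete, here is a minimal NumPy sketch for a single hard, axis-aligned split (a depth-1 tree) trained on synthetic data. It is not the thesis's implementation. The parameterization is an assumption chosen only for illustration: feature logits `a`, threshold `t`, leaf logits `l_left`/`l_right`, softmax and sigmoid surrogates, and temperature `tau`. The function names `forward`, `st_gradients` and `train_stump` are hypothetical. The forward pass always uses the hard argmax feature and the hard step split. Only the backward pass substitutes the surrogate derivatives, which is what lets plain gradient descent update both the feature choice and the threshold.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()


def forward(params, X, tau=0.1):
    """Hard forward pass of a depth-1, axis-aligned tree (hypothetical parameterization)."""
    p = softmax(params["a"])                                   # soft feature weights, used only in backward
    s_hard = np.eye(len(params["a"]))[np.argmax(params["a"])]  # hard one-hot feature choice
    z = X @ s_hard - params["t"]                               # selected feature minus threshold
    soft = sigmoid(z / tau)                                    # surrogate split, used only in backward
    hard = (z >= 0).astype(float)                              # hard split: 1.0 -> right leaf, 0.0 -> left leaf
    logit = hard * params["l_right"] + (1.0 - hard) * params["l_left"]
    return logit, (p, soft, hard, tau)


def st_gradients(params, X, y, logit, cache):
    """Gradients of the mean binary cross-entropy using straight-through surrogates."""
    p, soft, hard, tau = cache
    d_logit = (sigmoid(logit) - y) / len(y)
    grads = {
        "l_right": np.sum(d_logit * hard),
        "l_left": np.sum(d_logit * (1.0 - hard)),
    }
    d_hard = d_logit * (params["l_right"] - params["l_left"])
    # Straight-through for the split: use the sigmoid's derivative instead of the step's (zero) derivative.
    d_z = d_hard * soft * (1.0 - soft) / tau
    grads["t"] = -np.sum(d_z)
    # Straight-through for the feature choice: backpropagate through softmax instead of argmax.
    d_s = X.T @ d_z
    grads["a"] = p * (d_s - p @ d_s)
    return grads


def train_stump(X, y, steps=300, lr=0.1):
    # Hypothetical initialization: uniform feature logits, leaves of opposite sign.
    params = {"a": np.zeros(X.shape[1]), "t": 0.5, "l_left": -1.0, "l_right": 1.0}
    for _ in range(steps):
        logit, cache = forward(params, X)
        grads = st_gradients(params, X, y, logit, cache)
        for key in params:
            params[key] = params[key] - lr * grads[key]
    return params


# Synthetic data: only feature 1 determines the label.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(200, 2))
y = (X[:, 1] > 0.5).astype(float)

params = train_stump(X, y)
logit, _ = forward(params, X)
accuracy = np.mean((logit > 0) == (y == 1))
print("feature:", int(np.argmax(params["a"])), "threshold:", round(float(params["t"]), 2), "accuracy:", accuracy)
```

On this toy data, the stump starts by splitting on the uninformative feature 0. It is expected to switch to feature 1 and settle its threshold near 0.5, while every prediction remains a hard, single-feature split. A deeper tree would apply the same straight-through substitution at every internal node, so that all nodes are optimized jointly rather than greedily.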

Paper Structure (330 sections, 70 equations, 50 figures, 80 tables, 4 algorithms)

Figures (50)

  • Figure 1: Exemplary DT. A simple DT trained on a simulated dataset to predict whether a bank should approve a credit application. The decision process begins by checking whether the Loan Status is greater than or equal to $0.5$, indicating whether the customer already holds a personal loan with the bank. If the customer already has a loan ($1$), the bank should not grant additional credit. If the customer does not hold a loan ($0$), the model proceeds to assess the customer's Balance, comparing it to a threshold of $10{,}354.5$. If the balance is below the threshold, the credit application is denied. For higher balances, the model further evaluates the applicant’s Age: if the customer is older than 26.5 years, the credit is approved; otherwise, it is denied.
  • Figure 2: Decision Boundary Comparison. This figure illustrates the decision boundaries (i.e., the hypersurfaces that separate different classes in the input space based on the learned function) of a DT (A) and an NN (B) on a simulated credit dataset. The decision boundary of the DT in Figure 1 clearly partitions the data along axis-aligned splits. In contrast, the NN exhibits a more flexible decision boundary that lacks a similarly interpretable structure. Notably, the DT effectively captures the underlying patterns in the data, such as the distinct separation between Loan Status values of 0 and 1. The NN, however, introduces artificial variance by treating values between 0 and 1 differently, even though such values are not present in the dataset. This highlights the advantageous inductive bias of axis-aligned splits, which efficiently capture the relevant structure of the Loan Status variable.
  • Figure 3: Greedy vs. Gradient-Based DT. Two DTs trained on the Echocardiogram dataset. The CART DT (A) sequentially performs only locally optimal splits, thereby constraining the search space. The non-greedy, gradient-based algorithm (B) instead optimizes all splits jointly and is not confined to locally optimal choices, enabling it to find a solution that achieves significantly better performance.
  • Figure 4: Visualization of a DT's Decision-Making Process. (A) shows how the DT partitions the 2D feature space (Age and Fare) into regions, with each region corresponding to a prediction of whether a passenger survived. (B) shows the corresponding tree structure, detailing the decision rules at each node.
  • Figure 5: Decision Boundaries of Axis-Aligned vs. Oblique DT on Reduced Titanic Dataset. Because each split of the oblique DT (B) can use a linear combination of features, its decision boundaries are more flexible and are not restricted to being parallel to the feature axes, unlike those of the axis-aligned tree (A).
  • ...and 45 more figures
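The exemplary credit DT described in Figure 1 can be written directly as nested rules, which illustrates why such a tree is considered interpretable. The sketch below is a hypothetical re-implementation and the name `approve_credit` is illustrative. The caption's wording fixes the comparisons: "≥ 0.5" for Loan Status, "below" 10,354.5 for Balance, and "older than" 26.5 for Age. How the original tree treats values exactly at the thresholds is otherwise an assumption.

```python
def approve_credit(loan_status: float, balance: float, age: float) -> bool:
    """Hypothetical re-implementation of the exemplary DT from Figure 1."""
    # Root split: an existing personal loan (Loan Status >= 0.5) leads to denial.
    if loan_status >= 0.5:
        return False
    # Second split: a balance below the threshold leads to denial.
    if balance < 10354.5:
        return False
    # Third split: approve only applicants older than 26.5 years.
    return age > 26.5


print(approve_credit(0, 12000, 30))   # no loan, sufficient balance, older than 26.5
print(approve_credit(1, 50000, 40))   # already holds a loan
print(approve_credit(0, 12000, 25))   # too young
```

Because the root split is a single threshold on one feature, every Loan Status value in $[0, 0.5)$ follows the same path. This is the axis-aligned inductive bias that Figure 2 contrasts with the NN's decision boundary.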