Decision Trees That Remember: Gradient-Based Learning of Recurrent Decision Trees with Memory

Sascha Marton, Moritz Schneider

TL;DR

This work addresses the challenge of modeling long-range temporal dependencies with interpretable decision structures. It introduces ReMeDe Trees, a gradient-trained recurrent decision-tree architecture with an internal memory $M \in \mathbb{R}^{n_m}$. The model extends GradTree by incorporating memory into the input-state space $\tilde{X} = X \times M$, uses memory gating to update the hidden state, and learns both routing decisions and memory dynamics via backpropagation through time. On five synthetic proof-of-concept tasks, ReMeDe Trees reach perfect test accuracy with compact trees, learn recurrent behavior, and manipulate their internal memory effectively, matching LSTM baselines in these scenarios. The approach aims to bridge the interpretability of axis-aligned decision trees and the sequence-modeling power of recurrent architectures, with potential for integration into ensembles and broader time-series applications.

Abstract

Neural architectures such as Recurrent Neural Networks (RNNs), Transformers, and State-Space Models have shown great success in handling sequential data by learning temporal dependencies. Decision Trees (DTs), on the other hand, remain a widely used class of models for structured tabular data but are typically not designed to capture sequential patterns directly. Instead, DT-based approaches for time-series data often rely on feature engineering, such as manually incorporating lag features, which can be suboptimal for capturing complex temporal dependencies. To address this limitation, we introduce ReMeDe Trees, a novel recurrent DT architecture that integrates an internal memory mechanism, similar to RNNs, to learn long-term dependencies in sequential data. Our model learns hard, axis-aligned decision rules for both output generation and state updates, optimizing them efficiently via gradient descent. We provide a proof-of-concept study on synthetic benchmarks to demonstrate the effectiveness of our approach.

Paper Structure

This paper contains 26 sections, 20 equations, 2 figures, 2 tables.

Figures (2)

  • Figure 1: Minimal Recurrent Decision Tree Example. This figure shows an exemplary ReMeDe tree applied to a sign recognition task. The task is to memorize the sign of $x \in (-0.5,0.5)$ at the first position and predict it (-1 or 1) when a trigger value (1) appears, while intermediate positions hold zeros plus small noise. The figure depicts the minimal ReMeDe tree solving this task. At the root node, the tree checks whether the trigger occurs. If not (left branch), there are two cases: if the hidden state is zero, it is updated from the input and adopts the sign of the entry; otherwise, it remains unchanged. If the trigger occurs (right branch), the tree splits on the hidden state to predict the sign of the first value: negative for a negative hidden state, positive otherwise.
  • Figure 2: ReMeDe Tree Update Visualization. This figure shows a ReMeDe tree trained on a sign recognition task. The task is to memorize the sign of $x \in (-0.5,0.5)$ at the first position and predict it (-1 or 1) when a trigger value (1) appears, while intermediate positions hold zeros plus small noise.
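
The routing described in the Figure 1 caption can be written out by hand as a small recurrent function. The following is only an illustrative sketch, not the authors' implementation or a trained model: the names `remede_step` and `run_sequence`, the split threshold `0.75`, the exact-zero test on the hidden state, and the output `0.0` at non-trigger steps are all assumptions made for this example.

```python
from typing import Sequence, Tuple

# Assumed split value: any threshold between 0.5 and 1 separates the
# trigger (1) from first-position values in (-0.5, 0.5) and noisy zeros.
TRIGGER_THRESHOLD = 0.75


def remede_step(x: float, h: float) -> Tuple[float, float]:
    """Apply the hand-written tree once; returns (output, new hidden state)."""
    if x <= TRIGGER_THRESHOLD:
        # Left branch of the root: no trigger at this position.
        if h == 0.0:
            # Memory is still empty: store the sign of the current entry.
            # (A learned axis-aligned tree would use threshold splits here;
            # the equality test is a simplification for this sketch.)
            return 0.0, (1.0 if x >= 0.0 else -1.0)
        # Memory already set: leave it unchanged.
        return 0.0, h
    # Right branch of the root: trigger seen, predict from the hidden state.
    if h < 0.0:
        return -1.0, h
    return 1.0, h


def run_sequence(xs: Sequence[float]) -> float:
    """Unroll the tree over a sequence and return the output at the last step."""
    h = 0.0  # assumed initial memory value
    out = 0.0
    for x in xs:
        out, h = remede_step(x, h)
    return out


if __name__ == "__main__":
    print(run_sequence([0.3, 0.01, -0.02, 1.0]))   # positive first entry
    print(run_sequence([-0.3, 0.01, 0.02, 1.0]))   # negative first entry
```

The same function is used for both the prediction and the memory update, which mirrors the caption's description of a single tree whose leaves decide what to output and whether to overwrite the hidden state.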