
On the Power of Decision Trees in Auto-Regressive Language Modeling

Yulu Gan, Tomer Galanti, Tomaso Poggio, Eran Malach

TL;DR

This research reveals the unique computational abilities of ARDTs, aiming to broaden the architectural diversity in language model development.

Abstract

Originally proposed for handling time series data, Auto-regressive Decision Trees (ARDTs) have not yet been explored for language modeling. This paper delves into both the theoretical and practical applications of ARDTs in this new context. We theoretically demonstrate that ARDTs can compute complex functions, such as simulating automata, Turing machines, and sparse circuits, by leveraging "chain-of-thought" computations. Our analysis provides bounds on the size, depth, and computational efficiency of ARDTs, highlighting their surprising computational power. Empirically, we train ARDTs on simple language generation tasks, showing that they can learn to generate coherent and grammatically correct text on par with a smaller Transformer model. Additionally, we show that ARDTs can be used on top of transformer representations to solve complex reasoning tasks. This research reveals the unique computational abilities of ARDTs, aiming to broaden the architectural diversity in language model development.
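To make "auto-regressive decision tree" concrete, here is a minimal sketch. It assumes a single tree that splits on the token at individual positions of a fixed-length context window (left-padded with `<PAD>`) and emits a next token at each leaf. The tree is applied repeatedly, and each output is appended to the context. The `Node`/`Leaf` classes, the toy vocabulary and the hand-built tree are hypothetical illustrations, not the paper's trained model.

```python
from dataclasses import dataclass
from typing import List, Union

PAD = "<PAD>"


@dataclass
class Leaf:
    token: str  # next token emitted when a window reaches this leaf


@dataclass
class Node:
    position: int  # window index to inspect (negative counts from the end)
    value: str  # split rule: go to if_true when window[position] == value
    if_true: "Tree"
    if_false: "Tree"


Tree = Union[Node, Leaf]


def predict(tree: Tree, window: List[str]) -> str:
    """Route one fixed-length context window from the root down to a leaf."""
    while isinstance(tree, Node):
        tree = tree.if_true if window[tree.position] == tree.value else tree.if_false
    return tree.token


def generate(tree: Tree, prompt: List[str], context_length: int, steps: int) -> List[str]:
    """Auto-regressive decoding: every predicted token is appended and fed back in."""
    tokens = list(prompt)
    for _ in range(steps):
        window = tokens[-context_length:]
        window = [PAD] * (context_length - len(window)) + window  # left-pad short contexts
        tokens.append(predict(tree, window))
    return tokens


# Hypothetical toy tree: "the" -> "cat", "cat" -> "sat", anything else -> "the".
toy = Node(-1, "the", Leaf("cat"), Node(-1, "cat", Leaf("sat"), Leaf("the")))
print(generate(toy, ["the"], context_length=3, steps=4))
# ['the', 'cat', 'sat', 'the', 'cat']
```

The paper's models are trained tree ensembles over embedding features, not hand-built trees over raw tokens. The sketch only shows the decoding loop that makes a tree "auto-regressive".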

Paper Structure (21 sections, 5 theorems, 6 equations, 5 figures, 6 tables, 1 algorithm)

Key Result

Theorem 3

Let $\mathbb{D} = \Sigma \cup Q \cup \{\langle\mathrm{PAD}\rangle\}$. Then $\mathcal{F}^{\mathrm{Aut}}_n$ can be simulated by ARDTs of size $O(\lvert\mathbb{D}\rvert^2)$, depth $O(\log\lvert\mathbb{D}\rvert)$ and context length $L \ge n$, i…
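Theorem 3 appears here without its construction. The sketch below shows how chain-of-thought lets a shallow tree track an automaton. It rests on our own assumptions:

- $\mathcal{F}^{\mathrm{Aut}}_n$ is read as the functions computed by a deterministic finite automaton on inputs of length $n$.
- $\Sigma$ and $Q$ are disjoint.
- The context length is exactly $L = n$.
- The tree inspects only the first and last tokens of its window.

After $t-1$ generated states, the window holds $x_t, \dots, x_n, q_1, \dots, q_{t-1}$. Its first token is therefore the next input symbol, and its last token is the current state. This is an illustration, not the paper's proof, and every name in it is hypothetical.

```python
from typing import Callable, Dict, List, Set, Tuple

Delta = Dict[Tuple[str, str], str]  # (state, symbol) -> next state


def make_automaton_tree(delta: Delta, q0: str, states: Set[str]) -> Callable[[List[str]], str]:
    """Hypothetical ARDT step that inspects only window[0] and window[-1].

    Branching on two positions, each over at most |D| token values, needs at
    most |D|^2 leaves.
    """
    def tree(window: List[str]) -> str:
        symbol = window[0]  # with context length n, this is the next unread x_t
        last = window[-1]
        state = last if last in states else q0  # no state written yet: start at q0
        return delta[(state, symbol)]
    return tree


def run_ardt(tree: Callable[[List[str]], str], x: List[str]) -> List[str]:
    """Chain of thought: emit one state token per input symbol, context length L = n."""
    n = len(x)
    tokens = list(x)
    for _ in range(n):
        tokens.append(tree(tokens[-n:]))
    return tokens[n:]  # generated states q_1, ..., q_n


def run_dfa(delta: Delta, q0: str, x: List[str]) -> List[str]:
    """Reference: ordinary automaton run, recording the state after each symbol."""
    q, out = q0, []
    for s in x:
        q = delta[(q, s)]
        out.append(q)
    return out


# Parity automaton over {0, 1}; state names are disjoint from input symbols.
parity = {("even", "0"): "even", ("even", "1"): "odd",
          ("odd", "0"): "odd", ("odd", "1"): "even"}
tree = make_automaton_tree(parity, "even", {"even", "odd"})
print(run_ardt(tree, list("1011")))  # ['odd', 'odd', 'even', 'odd']
```

For brevity, the step function is written as a dictionary lookup. A decision tree would realize the same map by splitting on the two inspected positions. Splitting on a binary encoding of those two tokens gives depth about $2\log_2\lvert\mathbb{D}\rvert$, which is consistent with the $O(\log\lvert\mathbb{D}\rvert)$ depth in the theorem.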

Figures (5)

  • Figure 1: (a) An example of story continuation generated by our Auto-Regressive Decision Trees. Using only decision trees, we attain results comparable to Transformer-based models in linguistic fluency. (b) The decision process of the decision trees. We visualize part of the tree ensemble, showing which word is most relevant to the splitting rule at each node.
  • Figure 2: The pipeline of our method. (a) Training. First, we use a Word2Vec model to convert words into embeddings. Next, we use a sliding window to construct a dataset for training the decision trees: within each window we compute a weighted average of the embeddings, and the token that follows the window is used as the label. (b) Inference. We use the trained decision trees for next-token prediction.
  • Figure 3: t-SNE [vandermaaten2013barneshutsne] visualization of 20 cluster centers. We select 20 cluster centers and display the 4 words closest to each center.
  • Figure 4: Tracking the decision-making process within the decision trees. We use 'Lily and Tom loved to play together, and they found' as the input prompt and generate the next word with our ARDTs. We visualize part of the process within the decision tree, specifically 31 nodes of the first decision tree.
  • Figure 5: Feature Importance. We present the feature importance of the top 20 words most closely associated with each cluster, based on their average gain.
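The Figure 2 caption describes the training pipeline only at a high level. The sketch below shows one way to build its sliding-window dataset. It assumes integer token ids and a precomputed embedding matrix standing in for the Word2Vec vectors. It also assumes linearly increasing weights inside the window, because the caption mentions a weighted average but does not give the weights. The function names are hypothetical.

```python
import numpy as np


def window_weights(window: int, weights=None) -> np.ndarray:
    """Normalized averaging weights; the default (an assumption) favors recent tokens."""
    if weights is None:
        weights = np.arange(1, window + 1, dtype=float)  # 1, 2, ..., window
    weights = np.asarray(weights, dtype=float)
    return weights / weights.sum()


def build_dataset(token_ids, embeddings, window: int, weights=None):
    """Training pairs: weighted-average window embedding -> id of the following token."""
    token_ids = np.asarray(token_ids)
    w = window_weights(window, weights)
    X, y = [], []
    for start in range(len(token_ids) - window):
        ctx = embeddings[token_ids[start:start + window]]  # shape (window, dim)
        X.append(w @ ctx)
        y.append(token_ids[start + window])
    return np.array(X).reshape(-1, embeddings.shape[1]), np.array(y, dtype=int)


def inference_features(token_ids, embeddings, window: int, weights=None) -> np.ndarray:
    """Features for predicting the next token; assumes at least `window` tokens."""
    ctx = embeddings[np.asarray(token_ids)[-window:]]
    return window_weights(window, weights) @ ctx


rng = np.random.default_rng(0)
emb = rng.normal(size=(5, 8))  # stand-in for trained Word2Vec vectors (5 words, dim 8)
X, y = build_dataset([0, 1, 2, 3, 4, 0, 1], emb, window=3)
print(X.shape, y.tolist())  # (4, 8) [3, 4, 0, 1]
```

Any tree-ensemble learner could then be fit on the resulting `(X, y)` pairs. This summary does not say which library or hyperparameters the authors used. At inference time, the predicted token is appended and the window slides forward by one position.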

Theorems & Definitions (12)

  • Definition 2
  • Theorem 3
  • Theorem 4
  • proof: Proof of the lower-bound theorem (label thm:lower_bound)
  • Theorem 6
  • Theorem 7
  • proof: Proof of the circuits theorem (label thm:circuits)
  • Definition 9
  • Lemma 10
  • proof
  • ...and 2 more