Terminating Differentiable Tree Experts

Jonathan Thomm, Michael Hersche, Giacomo Camposampiero, Aleksandar Terzić, Bernhard Schölkopf, Abbas Rahimi

TL;DR

The paper addresses the parameter growth and termination bottlenecks of the Differentiable Tree Machine (DTM) by introducing Differentiable Tree Experts (DTE). DTE reuses a single Transformer encoder across steps and instantiates a Mixture of Experts (MoE) to propose operations, so its number of parameters stays constant regardless of the number of steps. A sluggish termination mechanism enables automatic, horizon-aware stopping without oracle guidance. Empirically, DTE and its terminating variant (TDTE) match the original DTM's in-distribution (ID) and out-of-distribution (OOD) performance on several tree-transformation tasks, while sparse MoE variants offer potential speedups. A key insight is that stronger generalization requires moving beyond the fixed inductive bias of the Lisp operators. The tree reversal experiments highlight this by exposing the gap between mere invariance and genuine OOD generalization.
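The MoE step mentioned above can be sketched as a router-weighted sum of expert outputs. This is a minimal, hypothetical sketch in plain Python, not the authors' implementation. The function names, the toy experts, and the choice to renormalize the top-k weights are our assumptions. In the paper, the router logits come from a transformer encoder layer followed by a linear map; here they are simply taken as given. Setting `top_k=4` mirrors the sparse ablation described in the caption of Figure 2.

```python
import math
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Expert = Callable[[Vector], Vector]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over the router logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def moe_combine(x: Vector, router_logits: Sequence[float],
                experts: Sequence[Expert], top_k: Optional[int] = None) -> Vector:
    """Weight each expert's output by its router probability.

    Dense MoE (top_k=None): every expert runs.
    Sparse MoE (assumed variant): only the top_k most probable experts run,
    and their probabilities are renormalised to sum to one.
    """
    probs = softmax(router_logits)
    ranked = sorted(range(len(experts)), key=lambda i: probs[i], reverse=True)
    active = ranked if top_k is None else ranked[:top_k]
    norm = sum(probs[i] for i in active)
    out = [0.0] * len(x)
    for i in active:
        w = probs[i] / norm
        out = [o + w * y for o, y in zip(out, experts[i](x))]
    return out


# Hypothetical toy experts: expert i scales its input by (i + 1). The same list
# is reused at every step, so the parameter count does not grow with the steps.
experts = [lambda v, s=s: [s * vi for vi in v] for s in (1.0, 2.0, 3.0, 4.0, 5.0)]
```

With uniform logits, the dense mixture averages the five scales to about 3.0. In contrast, `moe_combine([1.0, 1.0], [0, 0, 0, 0, 10], experts, top_k=1)` evaluates only the last expert. Skipping the unselected experts is where a sparse MoE can save compute.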

Abstract

We advance the recently proposed neuro-symbolic Differentiable Tree Machine, which learns tree operations using a combination of transformers and Tensor Product Representations. We investigate the architecture and propose two key components. First, we replace the series of distinct transformer layers used at each step with a mixture of experts. The resulting Differentiable Tree Experts model has a constant number of parameters for any number of computation steps, whereas the number of parameters in the Differentiable Tree Machine grows linearly with the number of steps. Given this flexibility in the number of steps, we additionally propose a new termination algorithm that lets the model choose the number of steps automatically. The resulting Terminating Differentiable Tree Experts model sluggishly learns to predict the number of steps without an oracle. It does so while maintaining the learning capabilities of the model, converging to the optimal number of steps.
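The termination idea can be illustrated with a deliberately simplified, hypothetical sketch based only on the caption of Figure 3. It uses two pieces. The first is a target number of steps that minimizes the loss plus a step penalty (the "best local termination"). The second is a damper that follows the explorer only once the explorer's confidence passes a threshold. All names, the linear step penalty, and the damper's discrete jump are our assumptions. The paper trains both predictors with dedicated losses that are not reproduced here.

```python
from typing import List, Sequence, Tuple


def best_local_step(step_losses: Sequence[float], step_penalty: float) -> int:
    """Return the number of steps n (1-based) minimising loss(n) + step_penalty * n.

    Hypothetical stand-in for the "best local termination" (purple dot in Figure 3);
    step_losses[i] is assumed to be the main model loss after i + 1 steps.
    """
    scores = [loss + step_penalty * (i + 1) for i, loss in enumerate(step_losses)]
    return 1 + min(range(len(scores)), key=scores.__getitem__)


def damper_update(damper_steps: int, explorer_steps: int,
                  explorer_confidence: float, threshold: float) -> int:
    """Simplified damper: jump to the explorer's step count once the explorer
    is confident; otherwise keep the current step count."""
    return explorer_steps if explorer_confidence >= threshold else damper_steps


def simulate(explorer_trace: Sequence[Tuple[int, float]],
             initial_steps: int, threshold: float) -> List[int]:
    """Replay (explorer_steps, explorer_confidence) pairs and return the
    damper's step count after each update."""
    damper = initial_steps
    history = []
    for explorer_steps, confidence in explorer_trace:
        damper = damper_update(damper, explorer_steps, confidence, threshold)
        history.append(damper)
    return history
```

For example, `simulate([(5, 0.2), (4, 0.6), (3, 0.95)], initial_steps=6, threshold=0.9)` keeps the damper at 6 steps until the explorer is confident, then moves it to 3. This illustrates why the damper changes the executed number of steps only sluggishly.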

Paper Structure (15 sections, 5 equations, 5 figures, 4 tables)

Figures (5)

  • Figure 1: The DTM architecture. In each step, a new tree superposition is generated as a Tensor Product Representation (TPR) using a different transformer encoder layer for each step; that layer also predicts the instruction probabilities. car, cdr, and cons are the three Lisp operations.
  • Figure 2: The architecture of the DTE. In each step, a new tree is generated using the same model. Our transformer encoder layer is now a Mixture of Experts (MoE) whose router is itself composed of a transformer encoder layer and a linear map. The router chooses the expert weights, which are then used to weight the outputs of each expert. In our sparse MoE ablations, only the top 4 experts are activated.
  • Figure 3: Cases of the sluggish termination losses. The two arrows indicate the labels of the termination predictors in each case; the cases are determined by whether each predictor is below or above the yellow confidence threshold. The orange dots mark the main model loss and the two (relatively small) residual losses. The purple dot is the best local termination (e.g., the best loss under some step penalty). The green predictor is called the "explorer" and the blue one the "damper", because the damper starts to follow the explorer once the explorer becomes confident and otherwise stays where it is.
  • Figure 4: Examples of the Car-Cdr-Seq, Passive↔Logical, and Active↔Logical datasets [dtm]. The model has to transform a source tree into the target tree. For the Passive↔Logical case, we show intermediate trees that the model could produce on the way to the target tree. We also show an example of lexical generalization that uses an unseen adjective (in this case "funny"), as well as one from the structural generalization test set, which adds further adjectives.
  • Figure 5: Visualization of how the DTM and our (T)DTE can solve our novel tree reversal task. With only the three operations cdr, car, and cons, reversing a tree takes several steps, and the number of steps grows with tree size: every child of a branching node must first be extracted before the tree can be reassembled in reverse order.
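The reversal in Figure 5 can be made concrete on discrete trees. The sketch below is hypothetical and uses the following simplifications:

- a binary tree is a nested Python tuple;
- car is "left child", cdr is "right child", and cons is "join two trees under a new root";
- the TPR encoding and the soft (weighted) operations the DTM actually uses are ignored.

Counting primitive operations shows why the required number of steps grows with tree size: every branching node costs one car, one cdr, and one cons.

```python
from typing import Tuple, Union

# A tree is either a leaf label or a (left, right) pair of subtrees.
Tree = Union[str, Tuple["Tree", "Tree"]]


def car(t: Tree) -> Tree:
    """Extract the left child of a branching node."""
    if isinstance(t, str):
        raise ValueError("car is undefined on a leaf")
    return t[0]


def cdr(t: Tree) -> Tree:
    """Extract the right child of a branching node."""
    if isinstance(t, str):
        raise ValueError("cdr is undefined on a leaf")
    return t[1]


def cons(left: Tree, right: Tree) -> Tree:
    """Join two trees as the left and right children of a new root."""
    return (left, right)


def reverse(t: Tree) -> Tuple[Tree, int]:
    """Mirror a tree using only car, cdr and cons.

    Returns the reversed tree and the number of primitive operations used.
    """
    if isinstance(t, str):
        return t, 0
    new_left, ops_right = reverse(cdr(t))   # old right subtree becomes the left
    new_right, ops_left = reverse(car(t))   # old left subtree becomes the right
    return cons(new_left, new_right), ops_right + ops_left + 3  # car + cdr + cons


print(reverse((("a", "b"), "c")))  # (('c', ('b', 'a')), 6)
```

Every branching node in this representation has exactly two children, so a tree with n leaves has n − 1 branching nodes. This discrete procedure therefore needs 3(n − 1) operations, a rough proxy for the growing number of steps the model must take.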