Differentiable Tree Search Network

Dixant Mittal, Wee Sun Lee

TL;DR

D-TSN tackles the data efficiency and generalization gaps of policy learning by embedding a differentiable best-first online search inside a neural network. It jointly optimizes a learned world model with a latent-space search that expands and backs up values, guided by a stochastic expansion policy and variance-reducing REINFORCE-based training with a telescoping-sum trick. The approach yields superior offline-RL performance on Procgen and a grid navigation task, with strong generalization under limited data and deeper search under practical computational budgets. This combination of planning bias, differentiable search, and robust training offers a scalable path to robust decision-making in data-scarce settings, though future work is needed to extend to stochastic environments and continuous actions.
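The telescoping-sum idea mentioned above can be illustrated with a small sketch. It assumes, purely for illustration, that each tree-expansion step t is rewarded by the loss reduction it causes, r_t = L_{t-1} - L_t; this summary does not state the exact reward definition, and the names `telescoped_returns` and `summed_returns` are hypothetical. Under that assumption the return-to-go collapses to L_{t-1} - L_T, so it depends on only two loss values rather than on every later reward.

```python
def telescoped_returns(losses):
    """Return-to-go for expansion steps t = 1..T using the collapsed form.

    Assumes reward r_t = L_{t-1} - L_t (an illustrative choice), so
    sum_{i=t}^{T} r_i = L_{t-1} - L_T.
    """
    final = losses[-1]
    return [losses[t - 1] - final for t in range(1, len(losses))]


def summed_returns(losses):
    """The same quantity computed the long way, by summing later rewards."""
    rewards = [losses[t - 1] - losses[t] for t in range(1, len(losses))]
    return [sum(rewards[t:]) for t in range(len(rewards))]


# Hypothetical losses L_0..L_3 measured after each expansion of the tree.
losses = [4.0, 2.5, 3.0, 1.0]
print(telescoped_returns(losses))  # [3.0, 1.5, 2.0]
```

Both functions agree on this input; the collapsed form simply avoids accumulating the intermediate terms.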

Abstract

In decision-making problems with limited training data, policy functions approximated using deep neural networks often exhibit suboptimal performance. An alternative approach involves learning a world model from the limited data and determining actions through online search. However, the performance is adversely affected by compounding errors arising from inaccuracies in the learned world model. While methods like TreeQN have attempted to address these inaccuracies by incorporating algorithmic inductive biases into the neural network architectures, the biases they introduce are often weak and insufficient for complex decision-making tasks. In this work, we introduce the Differentiable Tree Search Network (D-TSN), a novel neural network architecture that significantly strengthens the inductive bias by embedding the algorithmic structure of a best-first online search algorithm. D-TSN employs a learned world model to conduct a fully differentiable online search. The world model is jointly optimized with the search algorithm, enabling the learning of a robust world model and mitigating the effect of prediction inaccuracies. Further, we note that a naive incorporation of best-first search could lead to a discontinuous loss function in the parameter space. We address this issue by adopting a stochastic tree expansion policy, formulating search tree expansion as another decision-making task, and introducing an effective variance reduction technique for the gradient computation. We evaluate D-TSN in an offline-RL setting with limited training data on Procgen games and a grid navigation task, and demonstrate that D-TSN outperforms popular model-free and model-based baselines.
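To make the stochastic tree expansion idea concrete, here is a minimal sketch that assumes the expansion policy is a softmax over scalar scores of the candidate nodes. The abstract does not specify the policy's form, so the softmax choice and the names `sample_expansion` and `reinforce_surrogate` are assumptions. A hard argmax would switch nodes abruptly as the scores change, whereas sampling keeps the expected loss smooth and yields a log-probability that a REINFORCE-style estimator can weight by a return.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of scalar scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def sample_expansion(scores, rng):
    """Sample which candidate node to expand; return (index, log-probability)."""
    probs = softmax(scores)
    idx = rng.choices(range(len(scores)), weights=probs, k=1)[0]
    return idx, math.log(probs[idx])


def reinforce_surrogate(log_probs, returns):
    """Score-function surrogate: minus the sum of log-prob times return."""
    return -sum(lp * g for lp, g in zip(log_probs, returns))


rng = random.Random(0)
candidate_scores = [1.2, 0.4, -0.3]  # hypothetical scores of open nodes
idx, log_prob = sample_expansion(candidate_scores, rng)
surrogate = reinforce_surrogate([log_prob], [1.0])
```

In an automatic-differentiation framework the scores would come from the learned modules, and differentiating the surrogate would give the gradient estimate for the expansion policy.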

Paper Structure

This paper contains 59 sections, 5 theorems, 40 equations, 5 figures, 7 tables, and 1 algorithm.

Key Result

Theorem 3.1

Given a set of parameterised modules that are continuous in the parameter space $\theta$, the Q-function computed by fully expanding a search tree to a fixed depth $d$ by composing these modules and backpropagating the children values using addition and max operations is continuous in the parameter space $\theta$.
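The construction in the theorem can be sketched with toy modules. The scalar latent state, the two-action set, and the particular `transition`, `reward`, and `value` functions below are hypothetical stand-ins chosen only because they are continuous in `theta`; they are not the paper's networks. The recursion expands every action to a fixed depth and backs values up using only addition and max, so a small change in `theta` should produce a small change in the Q-value.

```python
import math

ACTIONS = (0, 1)  # hypothetical discrete action set


def transition(h, a, theta):
    """Toy latent transition, continuous in theta."""
    return math.tanh(theta[0] * h + theta[1] * (2 * a - 1))


def reward(h, a, theta):
    """Toy reward module, continuous in theta."""
    return theta[2] * h * (2 * a - 1)


def value(h, theta):
    """Toy leaf value module, continuous in theta."""
    return theta[3] * h


def q_value(h, a, depth, theta):
    """Q(h, a) from a tree fully expanded to `depth`, backed up with + and max."""
    r = reward(h, a, theta)
    h_next = transition(h, a, theta)
    if depth == 1:
        return r + value(h_next, theta)
    return r + max(q_value(h_next, b, depth - 1, theta) for b in ACTIONS)


theta = [0.5, 0.3, 1.0, 2.0]
nudged = [t + 1e-6 for t in theta]
q = q_value(0.1, 0, 3, theta)
q_nudged = q_value(0.1, 0, 3, nudged)
gap = abs(q - q_nudged)  # stays tiny because every operation is continuous
```

The max operation is not differentiable everywhere, but it is continuous, which is all this numerical check relies on.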

Figures (5)

  • Figure 1: A sample visualization of Procgen games (left) and Grid Navigation (right).
  • Figure 2: An illustration of the learnable submodules in Differentiable Tree Search Network.
  • Figure 3: An illustration of the Expansion Phase in Differentiable Tree Search Network.
  • Figure 4: An illustration of the Backup Phase in Differentiable Tree Search Network.
  • Figure 5: An illustration of the computation graph construction in Differentiable Tree Search Network.

Theorems & Definitions (7)

  • Theorem 3.1
  • Lemma 2.1
  • Lemma 2.2
  • Lemma 2.3
  • proof
  • Theorem 2.4
  • proof