Policy Gradient with Tree Expansion

Gal Dalal, Assaf Hallak, Gugan Thoppe, Shie Mannor, Gal Chechik

TL;DR

We prove that the closer the transitions induced by the tree-expansion policy are to being state-independent, the stronger the variance decay, and that with an approximate forward model the gradient bias diminishes with the approximation error while the same variance reduction is retained.

Abstract

Policy gradient methods are notorious for having a large variance and high sample complexity. To mitigate this, we introduce SoftTreeMax -- a generalization of softmax that employs planning. In SoftTreeMax, we extend the traditional logits with the multi-step discounted cumulative reward, topped with the logits of future states. We analyze SoftTreeMax and explain how tree expansion helps to reduce its gradient variance. We prove that the variance depends on the chosen tree-expansion policy. Specifically, we show that the closer the induced transitions are to being state-independent, the stronger the variance decay. With approximate forward models, we prove that the resulting gradient bias diminishes with the approximation error while retaining the same variance reduction. Ours is the first result to bound the gradient bias for an approximate model. In a practical implementation of SoftTreeMax, we utilize a parallel GPU-based simulator for fast and efficient tree expansion. Using this implementation in Atari, we show that SoftTreeMax reduces the gradient variance by three orders of magnitude. This leads to better sample complexity and improved performance compared to distributed PPO.

Paper Structure

This paper contains 30 sections, 13 theorems, 91 equations, 6 figures, and 2 algorithms.

Key Result

Lemma 4.1

Let $\nabla_\theta\log \pi_\theta(\cdot|s) \in \mathbb{R}^{A \times \dim(\theta)}$ be a matrix whose $a$-th row is $\nabla_\theta\log \pi_\theta(a|s)^\top$. For any parametric policy $\pi_\theta$ and function $Q^{\pi_\theta}:\mathcal{S}\times\mathcal{A} \rightarrow \mathbb{R},$
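For intuition about the matrix $\nabla_\theta\log \pi_\theta(\cdot|s)$ above, here is a minimal sketch for one specific policy class, a linear softmax policy $\pi_\theta(a|s)\propto\exp(\theta^\top\phi(s,a))$. The lemma itself applies to any parametric policy; the feature map $\phi$, the helper names, and the random numbers are illustrative assumptions. For this class the $a$-th row is $\phi(s,a)-\sum_b \pi_\theta(b|s)\,\phi(s,b)$.

```python
import numpy as np

def softmax(z):
    z = z - z.max()                # shift logits for numerical stability
    e = np.exp(z)
    return e / e.sum()

def log_policy_grad_matrix(theta, features):
    """A x dim(theta) matrix whose a-th row is grad_theta log pi_theta(a|s)^T.

    Assumes a linear softmax policy pi_theta(a|s) proportional to
    exp(theta^T phi(s, a)), with features[a] = phi(s, a) for a fixed state s.
    """
    probs = softmax(features @ theta)
    expected_feature = probs @ features   # sum_b pi_theta(b|s) phi(s, b)
    return features - expected_feature    # broadcasts over the A rows

# Hypothetical example: 4 actions, 3-dimensional features.
rng = np.random.default_rng(0)
theta = rng.normal(size=3)
features = rng.normal(size=(4, 3))
G = log_policy_grad_matrix(theta, features)
print(G.shape, np.linalg.norm(G, "fro"))
```

The rows, weighted by $\pi_\theta(\cdot|s)$, sum to zero, which is the familiar score-function identity $\mathbb{E}_{a\sim\pi_\theta}[\nabla_\theta\log\pi_\theta(a|s)]=0$; norms of the matrix, such as the Frobenius norm printed above, are then straightforward to evaluate.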

Figures (6)

  • Figure 1: A comparison of the empirical PG variance and our bound for E-SoftTreeMax on randomly drawn MDPs. We present three cases for $P^{\pi_b}$: (i) close to uniform, (ii) drawn randomly, and (iii) close to a permutation matrix. This experiment verifies the optimal and worst-case decay rates. The variance bounds here are taken from Theorem \ref{thm:rate_result2}, where we substitute $\alpha=|\lambda_2(P^{\pi_b})|$. To account for the constants, we match the bound to the empirical value at the first point, $d=1$.
  • Figure 2: SoftTreeMax policy. Our exhaustive parallel tree expansion iterates over all actions at each state up to depth $d$ ($=2$ here). The leaf state of every trajectory is used as input to the policy network. The output is then added to the trajectory's cumulative reward as described in Eq. \ref{eq:logit}. That is, instead of the standard softmax logits, we add the cumulative discounted reward to the policy network output. This policy is differentiable and can be easily integrated into any PG algorithm. In this work, we build on PPO and use its loss function to train the policy network.
  • Figure 3: Reward and Gradient variance: GPU SoftTreeMax (single worker) vs PPO ($\bf{256}$ GPU workers). The blue reward plots show the average of $50$ evaluation episodes. The red variance plots show the average gradient variance of the corresponding training runs, averaged over five seeds. The dashed lines represent the same for PPO. Note that the variance y-axis is in log-scale.
  • Figure 4: A diagram of the tree expansion used by SoftTreeMax. In every step, the states in the current level of the tree are duplicated and concatenated with each possible action. The resulting state-action pairs are then fed as a batch to the GPU simulator to generate the next level of states. Finally, the states of the last level $d$ are inserted into the neural network $W_\theta$ and the logits are computed using the corresponding rewards along each trajectory.
  • Figure 5: Training curves: GPU SoftTreeMax (single worker) vs PPO ($\bf{256}$ GPU workers). The plots show average reward and standard deviation over 5 seeds. The x-axis is the wall-clock time. The runs ended after one week, having completed varying numbers of time steps. The training curves correspond to the evaluation runs in Figure 3.
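The tree expansion of Figures 2 and 4 can be sketched on a toy model. This is a minimal sketch under assumptions that are ours, not the paper's: a deterministic tabular model given by hypothetical arrays `next_state[s, a]` and `reward[s, a]` stands in for the GPU simulator, a lookup table `leaf_logit[s]` stands in for the network $W_\theta$, the leaf output is discounted by $\gamma^d$, and action probabilities come from a softmax over all $A^d$ trajectory logits followed by summing the probability mass of trajectories that share a first action. With $d=1$ this reduces to a softmax over $r(s,a)+\gamma\,W(s')$.

```python
import numpy as np

def expand_tree(root, next_state, reward, depth, gamma):
    """Exhaustively expand all A**depth action sequences from `root`.

    Mirrors the batched scheme of Figure 4 on a toy tabular model: each
    level's states are duplicated once per action, paired with that action,
    and all (state, action) pairs are stepped together as one batch.
    Returns leaf states, discounted cumulative rewards, and first actions.
    """
    assert depth >= 1
    num_actions = next_state.shape[1]
    states = np.array([root])
    returns = np.zeros(1)
    first_actions = np.zeros(1, dtype=int)
    for t in range(depth):
        states_rep = np.repeat(states, num_actions)            # duplicate states
        actions = np.tile(np.arange(num_actions), len(states))  # pair with each action
        returns = np.repeat(returns, num_actions) + gamma**t * reward[states_rep, actions]
        first_actions = actions if t == 0 else np.repeat(first_actions, num_actions)
        states = next_state[states_rep, actions]  # one batched "simulator" step
    return states, returns, first_actions

def softtreemax_policy(root, next_state, reward, leaf_logit, depth, gamma, beta=1.0):
    """Hypothetical SoftTreeMax-style action distribution at `root`."""
    leaves, returns, first = expand_tree(root, next_state, reward, depth, gamma)
    # Trajectory logit: discounted reward plus discounted leaf output.
    logits = beta * (returns + gamma**depth * leaf_logit[leaves])
    w = np.exp(logits - logits.max())   # stable softmax over trajectories
    w /= w.sum()
    # Sum trajectory probabilities by first action to get pi(a | root).
    return np.bincount(first, weights=w, minlength=next_state.shape[1])

# Toy 3-state, 2-action model (hypothetical numbers).
next_state = np.array([[1, 2], [0, 2], [2, 0]])
reward = np.array([[1.0, 0.0], [0.5, 0.2], [0.0, 1.0]])
leaf_logit = np.array([0.1, -0.3, 0.7])
print(softtreemax_policy(0, next_state, reward, leaf_logit, depth=3, gamma=0.9))
```

The batch grows as $A^d$, so exhaustive expansion of this kind is only practical for shallow depths or small action sets.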

Theorems & Definitions (30)

  • Lemma 4.1: Bound on the policy gradient variance
  • Lemma 4.2: Vector form of C-SoftTreeMax
  • Lemma 4.3: Gradient of C-SoftTreeMax
  • Theorem 4.4: Variance decay of C-SoftTreeMax
  • Proof outline
  • Lemma 4.5: Vector form of E-SoftTreeMax
  • Lemma 4.6: Gradient of E-SoftTreeMax
  • Theorem 4.7: Variance decay of E-SoftTreeMax
  • Theorem 4.8
  • Proof outline
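Figure 1 ties the variance decay of E-SoftTreeMax to $\alpha=|\lambda_2(P^{\pi_b})|$. As a small sanity check of the intuition that state-independent transitions give the strongest decay, the sketch below computes $\alpha$ for matrices resembling the three cases of Figure 1. A state-independent transition matrix has identical rows, hence rank one, so $\alpha=0$; a permutation matrix has all its eigenvalues on the unit circle, so $\alpha=1$. The matrix size, the 0.9 mixing weight, and the helper names are illustrative assumptions, not the settings of the paper's experiment.

```python
import numpy as np

def second_eigenvalue_modulus(P):
    """Return |lambda_2(P)|, the second-largest eigenvalue modulus of P."""
    moduli = np.sort(np.abs(np.linalg.eigvals(P)))[::-1]
    return float(moduli[1])

def figure1_style_cases(n=5, mix=0.9, seed=0):
    """|lambda_2| for transition matrices like those in Figure 1 (illustrative)."""
    rng = np.random.default_rng(seed)
    random_P = rng.random((n, n))
    random_P /= random_P.sum(axis=1, keepdims=True)  # make rows sum to 1
    uniform = np.full((n, n), 1.0 / n)               # identical rows: state-independent
    perm = np.eye(n)[rng.permutation(n)]             # random permutation matrix
    cases = {
        "state-independent": uniform,
        "close to uniform": mix * uniform + (1 - mix) * random_P,
        "random": random_P,
        "close to permutation": mix * perm + (1 - mix) * random_P,
        "permutation": perm,
    }
    return {name: second_eigenvalue_modulus(P) for name, P in cases.items()}

for name, alpha in figure1_style_cases().items():
    print(f"{name:>22}: alpha = {alpha:.3f}")
```

Mixing a stochastic matrix with a rank-one matrix of identical rows scales every non-unit eigenvalue by the weight left on the original matrix (here $1-0.9=0.1$), which is why the "close to uniform" case sits near zero while the "close to permutation" case stays near one.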