
Enhancing Decision Transformer with Diffusion-Based Trajectory Branch Generation

Zhihong Liu, Long Qian, Zeyang Liu, Lipeng Wan, Xingyu Chen, Xuguang Lan

TL;DR

Diffusion-Based Trajectory Branch Generation (BG) expands the trajectories of the dataset with branches generated by a diffusion model. After this expansion, DT outperforms state-of-the-art sequence modeling methods on the D4RL benchmark, demonstrating the effectiveness of adding branches to the dataset without further modifications.

Abstract

Decision Transformer (DT) can learn effective policies from offline datasets by converting offline reinforcement learning (RL) into a supervised sequence modeling task, in which trajectory elements are generated auto-regressively conditioned on the return-to-go (RTG). However, this sequence modeling approach tends to learn policies that converge to the sub-optimal trajectories within the dataset, because the dataset lacks bridging data that would let the policy move to better trajectories, even when the condition is set to the highest RTG. To address this issue, we introduce Diffusion-Based Trajectory Branch Generation (BG), which expands the trajectories of the dataset with branches generated by a diffusion model. Each trajectory branch is generated from a segment of a trajectory in the dataset and leads to trajectories with higher returns. We concatenate the generated branch with the trajectory segment as an expansion of the trajectory. After this expansion, DT has more opportunities to learn policies that move to better trajectories, which prevents it from converging to sub-optimal trajectories. Empirically, after processing with BG, DT outperforms state-of-the-art sequence modeling methods on the D4RL benchmark, demonstrating the effectiveness of adding branches to the dataset without further modifications.
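
For reference, the RTG that DT conditions on is, in the standard undiscounted formulation, the sum of rewards from step t to the end of the trajectory, R_t = sum_{t'=t}^{T} r_{t'}. The minimal sketch below computes it for one trajectory; the function name `returns_to_go` is hypothetical and not taken from the paper.

```python
from typing import List


def returns_to_go(rewards: List[float]) -> List[float]:
    """Undiscounted return-to-go: R_t = sum of rewards from step t to the end."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each entry accumulates all rewards at or after step t.
    for t in range(len(rewards) - 1, -1, -1):
        running += rewards[t]
        rtg[t] = running
    return rtg


if __name__ == "__main__":
    print(returns_to_go([1.0, 0.0, 2.0, 3.0]))  # [6.0, 5.0, 5.0, 3.0]
```

At evaluation time DT is given a target RTG (for example, the highest return seen in the dataset) as its first condition, which is the setting the abstract refers to.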

Enhancing Decision Transformer with Diffusion-Based Trajectory Branch Generation

TL;DR

Diffusion-Based Trajectory Branch Generation (BG), which expands the trajectories of the dataset with branches generated by a diffusion model, and outperforms state-of-the-art sequence modeling methods on D4RL benchmark, demonstrating the effectiveness of adding branches to the dataset without further modifications.

Abstract

Decision Transformer (DT) can learn effective policy from offline datasets by converting the offline reinforcement learning (RL) into a supervised sequence modeling task, where the trajectory elements are generated auto-regressively conditioned on the return-to-go (RTG).However, the sequence modeling learning approach tends to learn policies that converge on the sub-optimal trajectories within the dataset, for lack of bridging data to move to better trajectories, even if the condition is set to the highest RTG.To address this issue, we introduce Diffusion-Based Trajectory Branch Generation (BG), which expands the trajectories of the dataset with branches generated by a diffusion model.The trajectory branch is generated based on the segment of the trajectory within the dataset, and leads to trajectories with higher returns.We concatenate the generated branch with the trajectory segment as an expansion of the trajectory.After expanding, DT has more opportunities to learn policies to move to better trajectories, preventing it from converging to the sub-optimal trajectories.Empirically, after processing with BG, DT outperforms state-of-the-art sequence modeling methods on D4RL benchmark, demonstrating the effectiveness of adding branches to the dataset without further modifications.

Paper Structure

This paper contains 28 sections, 12 equations, 4 figures, and 3 tables.

Figures (4)

  • Figure 1: A maze example to illustrate the problem of DT converging to sub-optimal trajectories and the importance of trajectory branches. Eval Trajectory refers to the trajectory generated by the policy during evaluation.
  • Figure 2: Overall pipeline of BG. Trajectory segments are randomly sampled from the trajectories of the dataset. The TVF generates $\tilde{g}_{t}$ based on each trajectory segment. $\tilde{g}_{t}$ and the trajectory segment are combined as the condition and fed into the EDM diffusion model. The generated branches are concatenated to the trajectory segments as expansions of the trajectories in the dataset. DT is then trained on the expanded dataset.
  • Figure 3: Branch demonstration on the Maze2d-large map. The black pentagram represents the goal. Each trajectory runs in the direction in which its color changes from light to dark.
  • Figure 4: Branch demonstrations on Gym tasks. Branches in HalfCheetah have length 20, but only steps 5 to 15 of each branch are shown to reduce the figure width.
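
The data-expansion step described in the Figure 2 caption can be sketched structurally as follows. This is not the authors' implementation: the TVF and the EDM diffusion model are replaced by hypothetical stub functions (`predict_target_value`, `generate_branch`), trajectories are assumed to be dicts of NumPy arrays with `observations`, `actions`, and `rewards` keys, and branches are assumed to contain the same three fields. The sketch only shows how a sampled segment and a generated branch are concatenated into an extra trajectory that DT could then be trained on.

```python
import numpy as np

# Assumed trajectory fields; real datasets may carry more (e.g. terminals).
KEYS = ("observations", "actions", "rewards")


def sample_segment(traj, seg_len, rng):
    """Pick a random contiguous segment of length seg_len (needs len(traj) >= seg_len)."""
    T = len(traj["rewards"])
    start = int(rng.integers(0, T - seg_len + 1))
    return {k: traj[k][start:start + seg_len] for k in KEYS}


def predict_target_value(segment):
    """Hypothetical stand-in for the TVF producing g~_t; a real TVF is learned."""
    return float(segment["rewards"].sum()) * 1.5  # placeholder heuristic only


def generate_branch(segment, target, branch_len, rng):
    """Hypothetical stand-in for the conditional EDM sampler.

    A real model would denoise a branch conditioned on (segment, target);
    here we only return arrays with the right shapes.
    """
    s_dim = segment["observations"].shape[1]
    a_dim = segment["actions"].shape[1]
    return {
        "observations": rng.normal(size=(branch_len, s_dim)),
        "actions": rng.uniform(-1.0, 1.0, size=(branch_len, a_dim)),
        "rewards": np.full(branch_len, target / branch_len),
    }


def expand_dataset(dataset, seg_len, branch_len, n_branches, rng):
    """Return the original trajectories plus n_branches (segment + branch) expansions."""
    expanded = list(dataset)
    for _ in range(n_branches):
        traj = dataset[int(rng.integers(len(dataset)))]
        seg = sample_segment(traj, seg_len, rng)
        g = predict_target_value(seg)
        branch = generate_branch(seg, g, branch_len, rng)
        # Concatenate along time: the branch continues from the end of the segment.
        expanded.append({k: np.concatenate([seg[k], branch[k]], axis=0) for k in KEYS})
    return expanded


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    toy = [
        {
            "observations": rng.normal(size=(50, 3)),
            "actions": rng.uniform(-1.0, 1.0, size=(50, 2)),
            "rewards": rng.uniform(0.0, 1.0, size=50),
        }
        for _ in range(4)
    ]
    out = expand_dataset(toy, seg_len=10, branch_len=20, n_branches=3, rng=rng)
    print(len(out), out[-1]["observations"].shape)  # 7 (30, 3)
```

The branch length of 20 mirrors the HalfCheetah setting mentioned in the Figure 4 caption; the segment length and number of branches are arbitrary illustrative values. Keeping the original trajectories alongside the expansions reflects the description of BG as adding branches to the dataset rather than replacing data.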