DiffStitch: Boosting Offline Reinforcement Learning with Diffusion-based Trajectory Stitching

Guanghe Li, Yixiang Shan, Zhengbang Zhu, Ting Long, Weinan Zhang

TL;DR

Diffusion-based Trajectory Stitching (DiffStitch) is a diffusion-based data augmentation pipeline that systematically generates stitching transitions between trajectories. By connecting low-reward trajectories with high-reward trajectories, it forms globally optimal trajectories that address the challenges faced by offline RL algorithms.

Abstract

In offline reinforcement learning (RL), the performance of the learned policy depends heavily on the quality of the offline dataset. However, in many cases, the offline dataset contains very few optimal trajectories, which poses a challenge for offline RL algorithms because agents must acquire the ability to transition to high-reward regions. To address this issue, we introduce Diffusion-based Trajectory Stitching (DiffStitch), a novel diffusion-based data augmentation pipeline that systematically generates stitching transitions between trajectories. DiffStitch effectively connects low-reward trajectories with high-reward trajectories, forming globally optimal trajectories to address the challenges faced by offline RL algorithms. Empirical experiments conducted on D4RL datasets demonstrate the effectiveness of DiffStitch across RL methodologies. Notably, DiffStitch substantially improves the performance of one-step methods (IQL), imitation learning methods (TD3+BC), and trajectory optimization methods (DT).
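
To make the idea of a stitching transition concrete, the following is a minimal toy sketch, not the authors' implementation. It bridges the last state of a low-reward trajectory to the first state of a high-reward trajectory. The function name `stitch_states`, the parameter `num_bridge_steps`, and the use of linear interpolation are all assumptions made for illustration; in DiffStitch the connecting states are produced by a diffusion model, which this sketch replaces with a simple placeholder.

```python
import numpy as np


def stitch_states(low_states, high_states, num_bridge_steps=3):
    """Toy stitching: join two state sequences with placeholder bridge states.

    low_states, high_states: arrays of shape (T, state_dim).
    The bridge is linear interpolation, a stand-in (assumption) for the
    diffusion-based generator described in the abstract.
    """
    start = low_states[-1]   # where the low-reward trajectory ends
    goal = high_states[0]    # where the high-reward trajectory begins

    # Interpolation fractions strictly between 0 and 1, so the bridge
    # does not repeat the endpoints already present in the data.
    fractions = np.linspace(0.0, 1.0, num_bridge_steps + 2)[1:-1]
    bridge = np.array([(1.0 - f) * start + f * goal for f in fractions])

    # The stitched trajectory is: low-reward part, bridge, high-reward part.
    return np.concatenate([low_states, bridge, high_states], axis=0)
```

For example, bridging a 4-step trajectory to a 5-step trajectory with three bridge states yields a single 12-step state sequence. The sketch covers states only; it leaves out the actions and rewards that a usable augmented transition would also need.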

Paper Structure (27 sections, 13 equations, 5 figures, 4 tables, 1 algorithm)

Figures (5)

  • Figure 1: An illustration of trajectory stitching. Suppose there are two trajectories (blue and yellow) in the offline dataset, and the objective for the agent is to learn a policy that starts from $S$ and reaches $G$. (a) and (b) present the scenarios where the trajectories in the offline dataset intersect or are in close proximity, making it easier to learn a policy that leads to $G$. (c) presents the scenario where the trajectories are far apart, posing a challenge for learning a viable policy. (d) illustrates previous solutions, which generate trajectories based on the original data to enhance policy learning. Although many branches extend from the original trajectory that starts at $S$, none of them forms a trajectory that starts from $S$ and reaches $G$. (e) illustrates our solution, which stitches together a trajectory starting from $S$ and ending at $G$, facilitating policy learning by providing a clear path to follow.
  • Figure 2: The overall pipeline framework of DiffStitch.
  • Figure 3: The stitching trajectory.
  • Figure 4: Return of states before and after stitching.
  • Figure 5: Ablation study on Hopper.