
DiffCPS: Diffusion Model based Constrained Policy Search for Offline Reinforcement Learning

Longxiang He, Li Shen, Linrui Zhang, Junbo Tan, Xueqian Wang

TL;DR

The theoretical analysis reveals that strong duality holds for diffusion-based CPS problems, and that, upon introducing parameter approximation, an approximate solution can be obtained after $\mathcal{O}(1/\epsilon)$ dual iterations, where $\epsilon$ denotes the representation ability of the parametrized policy.

Abstract

Constrained policy search (CPS) is a fundamental problem in offline reinforcement learning and is generally solved by advantage weighted regression (AWR). However, previous methods may still produce out-of-distribution actions due to the limited expressivity of Gaussian-based policies. On the other hand, directly applying state-of-the-art models with strong distribution-expression capabilities (i.e., diffusion models) within the AWR framework is infeasible, since AWR requires exact policy probability densities, which are intractable for diffusion models. In this paper, we propose a novel approach, $\textbf{Diffusion-based Constrained Policy Search}$ (dubbed DiffCPS), which solves diffusion-based constrained policy search with the primal-dual method. The theoretical analysis reveals that strong duality holds for diffusion-based CPS problems, and that, upon introducing parameter approximation, an approximate solution can be obtained after $\mathcal{O}(1/\epsilon)$ dual iterations, where $\epsilon$ denotes the representation ability of the parametrized policy. Extensive experimental results on the D4RL benchmark demonstrate the efficacy of our approach. We empirically show that DiffCPS achieves better or at least competitive performance compared to traditional AWR-based baselines as well as recent diffusion-based offline RL methods. The code is now available at https://github.com/felix-thu/DiffCPS.
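The primal-dual idea can be illustrated on a toy discrete bandit. This is a minimal sketch under loudly stated assumptions, not the DiffCPS implementation: we replace the diffusion policy with a categorical distribution so that the primal step has the closed form $\mu(a) \propto \pi_b(a)\exp(Q(a)/\lambda)$ (the same form AWR targets), and we use the reverse divergence $\mathrm{KL}(\mu \Vert \pi_b)$ as the constraint purely for convenience; the constraint and update rules in the paper may differ. The helper names `primal_step` and `dual_loop`, the step size `lr`, and the floor `lam_min` are hypothetical. `lam_min` loosely mimics a clipping bound on $\lambda$; reading it as the analogue of $\lambda_{\text{clip}}$ in Figure 3 is our assumption.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def primal_step(q_values, behavior, lam):
    """Closed-form maximizer of E_mu[Q] - lam * KL(mu || pi_b) over categorical mu.

    The solution is mu(a) proportional to pi_b(a) * exp(Q(a) / lam).
    """
    q_max = max(q_values)  # subtract the max for numerical stability
    w = [b * math.exp((qv - q_max) / lam) for qv, b in zip(q_values, behavior)]
    z = sum(w)
    return [x / z for x in w]


def dual_loop(q_values, behavior, eps, lam0=1.0, lr=1.0, lam_min=1e-3, iters=2000):
    """Projected gradient descent on the dual variable lam (toy sketch).

    Dual function: g(lam) = max_mu E_mu[Q] - lam * (KL(mu || pi_b) - eps),
    whose derivative in lam is eps - KL(mu_lam || pi_b).
    """
    lam = lam0
    for _ in range(iters):
        mu = primal_step(q_values, behavior, lam)  # primal: best response to current lam
        grad = eps - kl(mu, behavior)              # dual gradient
        lam = max(lam_min, lam - lr * grad)        # descend on g; clip at hypothetical floor
    return primal_step(q_values, behavior, lam), lam


if __name__ == "__main__":
    behavior = [0.7, 0.2, 0.1]  # hypothetical behavior policy pi_b
    q_values = [0.0, 1.0, 2.0]  # hypothetical critic values Q(s, a)
    mu, lam = dual_loop(q_values, behavior, eps=0.1)
    print("lambda:", round(lam, 3))
    print("KL(mu || pi_b):", round(kl(mu, behavior), 4))
    print("E_mu[Q]:", round(sum(m * q for m, q in zip(mu, q_values)), 3))
```

When the constraint is violated (KL above `eps`), the dual gradient is negative, so $\lambda$ grows and pulls the policy back toward $\pi_b$; when there is slack, $\lambda$ shrinks and the policy moves toward high-$Q$ actions. In this toy the primal step is exact; with a parametrized diffusion policy it would instead be approximated by gradient steps on the policy parameters.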

Paper Structure

This paper contains 27 sections, 9 theorems, 61 equations, 9 figures, 6 tables, and 1 algorithm.

Key Result

Theorem 3.1

Let $\mu({\bm{a}}\vert{\bm{s}})$ be a diffusion-based policy and $\pi_b$ be the behavior policy. Then, we have

Figures (9)

  • Figure 1: Toy offline experiment on a simple bandit task. We compare AWR with diffusion-based offline RL algorithms (DQL wang2022b and SfBC chen2022). The first row displays the actions taken by the trained policies, where $T$ denotes the number of diffusion steps. AWR fails to capture the multi-modal actions in the offline dataset due to the limited expressivity of a unimodal Gaussian policy. The second row shows the effect of different numbers of diffusion steps $T$.
  • Figure 2: Ablation studies of diffusion steps $T$ on selected Gym tasks (three random seeds). We observe that as $T$ increases, the training stability improves, but the final performance drops.
  • Figure 3: Ablation studies of $\lambda_{\text{clip}}$ in AntMaze and MuJoCo tasks. We observe that $\lambda_{\text{clip}}$ has little impact on MuJoCo tasks but significantly influences AntMaze tasks, especially as the AntMaze datasets get larger. The reason is that the sparse rewards and suboptimal trajectories in AntMaze datasets make the critic network prone to estimation errors, which leads to poor policies. Therefore, the policy needs to learn more from the original dataset, which means we should increase $\lambda$ or tighten the KL constraint. We find that increasing $\lambda_{\text{clip}}$ while maintaining a moderate KL constraint achieves the best results. All results are obtained by evaluating three random seeds.
  • Figure 4: Ablation studies of the policy evaluation interval in AntMaze and MuJoCo tasks. Delayed policy updates have a relatively minor impact on the MuJoCo locomotion tasks. However, for large-scale sparse-reward datasets such as AntMaze Large, choosing an appropriate update frequency can substantially improve the final results. The MuJoCo results are obtained with 2 million training steps (three random seeds), and the AntMaze results with 1 million training steps (three random seeds).
  • Figure 5: Evaluation performance of DiffCPS and other baselines on toy bandit experiments. The dashed line represents the score of AWR. We also observe that as $T$ increases, all diffusion-based algorithms suffer some performance decline, especially SfBC. A possible reason is that as $T$ increases, the larger model capacity leads to overfitting the dataset. For SfBC, sampling errors further exacerbate this phenomenon.
  • ...and 4 more figures
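The failure mode described in Figure 1, where a unimodal Gaussian cannot represent multi-modal actions, can be reproduced in a few lines. The sketch below is a hypothetical setup, not the paper's bandit task: the dataset, the function name `awr_gaussian_fit`, and the temperature `beta` are our own illustrative choices, and we use the closed-form advantage-weighted maximum-likelihood fit of a 1-D Gaussian instead of gradient-based AWR training.

```python
import math


def awr_gaussian_fit(actions, advantages, beta=1.0):
    """Advantage-weighted maximum-likelihood fit of a 1-D Gaussian policy.

    With weights w_i = exp(A_i / beta), maximizing sum_i w_i * log N(a_i; m, s^2)
    gives the weighted mean and the weighted (biased) standard deviation.
    """
    w = [math.exp(adv / beta) for adv in advantages]
    z = sum(w)
    mean = sum(wi * a for wi, a in zip(w, actions)) / z
    var = sum(wi * (a - mean) ** 2 for wi, a in zip(w, actions)) / z
    return mean, math.sqrt(var)


# Hypothetical bimodal bandit dataset: two equally good actions at -1 and +1.
actions = [-1.0] * 50 + [1.0] * 50
advantages = [0.0] * len(actions)

mean, std = awr_gaussian_fit(actions, advantages)
# Distance from the Gaussian's mode (its mean) to the nearest action in the data.
gap = min(abs(mean - a) for a in actions)
print(f"mean={mean:.2f}, std={std:.2f}, distance to nearest data action={gap:.2f}")
```

Because both modes are equally good, the weighted fit centers the Gaussian at 0, an action that never appears in the data and lies a full unit from either mode. This is the kind of out-of-distribution action the abstract attributes to Gaussian-based policies; a more expressive policy class such as a diffusion model can instead place probability mass on both modes.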

Theorems & Definitions (20)

  • Theorem 3.1
  • Corollary 3.2
  • Remark 3.3
  • Remark 3.6
  • Theorem 3.7
  • Proposition 3.8
  • Remark 3.9
  • Definition 3.10
  • Theorem 3.11
  • Remark 3.12
  • ...and 10 more