
Stabilizing Policy Gradients for Stochastic Differential Equations via Consistency with Perturbation Process

Xiangxin Zhou, Liang Wang, Yichi Zhou

TL;DR

The framework offers a general approach that supports a versatile selection of policy gradient methods for training SDEs effectively and efficiently. The algorithm is evaluated on structure-based drug design, where it optimizes the binding affinity of generated ligand molecules.

Abstract

To generate samples with high rewards, we focus on optimizing stochastic differential equations (SDEs) parameterized by deep neural networks, which are advanced generative models with high expressiveness, using policy gradients, the leading class of algorithms in reinforcement learning. However, when policy gradients are applied to SDEs, the gradient is estimated from a finite set of trajectories, so it can be ill-defined, and the policy's behavior in data-scarce regions may be uncontrolled. This compromises the stability of policy gradients and worsens sample complexity. To address these issues, we propose constraining the SDE to be consistent with its associated perturbation process. Because the perturbation process covers the entire space and is easy to sample, this constraint mitigates both problems. Our framework offers a general approach that supports a versatile selection of policy gradient methods for training SDEs effectively and efficiently. We evaluate our algorithm on structure-based drug design, optimizing the binding affinity of generated ligand molecules. Our method achieves the best Vina score, -9.07, on the CrossDocked2020 dataset.
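The claim that the perturbation process is easy to sample can be illustrated in one dimension. The sketch below assumes a variance-preserving forward SDE with constant $\beta = 1$ (an illustrative choice; the paper's noise schedule is not given here), and every function name in it is hypothetical. It draws $x_t \sim p_{t0}(x_t \mid x_0)$ from the closed-form Gaussian kernel in a single step and compares the result with integrating the same SDE by Euler-Maruyama.

```python
import math
import random

# Illustrative assumption (not the paper's schedule): a variance-preserving forward SDE
# with constant beta = 1,  dx = -0.5 * x dt + dW,  whose transition kernel is Gaussian:
#   p_t0(x_t | x_0) = N(alpha_t * x_0, 1 - alpha_t**2),  alpha_t = exp(-t / 2).


def alpha(t: float) -> float:
    """Signal-scaling factor alpha_t of the assumed forward SDE."""
    return math.exp(-0.5 * t)


def perturb_closed_form(x0: float, t: float, rng: random.Random) -> float:
    """Draw x_t ~ p_t0(x_t | x_0) in one step, without simulating the SDE."""
    a = alpha(t)
    return a * x0 + math.sqrt(1.0 - a * a) * rng.gauss(0.0, 1.0)


def perturb_euler_maruyama(x0: float, t: float, n_steps: int, rng: random.Random) -> float:
    """Draw x_t by integrating the same forward SDE with n_steps Euler-Maruyama steps."""
    dt = t / n_steps
    x = x0
    for _ in range(n_steps):
        x += -0.5 * x * dt + math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


def mean_var(xs: list[float]) -> tuple[float, float]:
    """Sample mean and unbiased sample variance."""
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


if __name__ == "__main__":
    rng = random.Random(0)
    x0, t, n = 2.0, 0.5, 10_000
    one_step = [perturb_closed_form(x0, t, rng) for _ in range(n)]
    simulated = [perturb_euler_maruyama(x0, t, 50, rng) for _ in range(n)]
    print("exact moments :", (alpha(t) * x0, 1.0 - alpha(t) ** 2))
    print("one-step draw :", mean_var(one_step))
    print("50-step EM    :", mean_var(simulated))
```

Both samplers should report moments close to the exact ones, but the one-step draw costs a single Gaussian sample per $x_t$ while the simulated path costs one per step, which is why perturbed samples are cheap compared with SDE rollouts.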

Paper Structure

This paper contains 32 sections, 5 theorems, 23 equations, 11 figures, 1 table, and 2 algorithms.

Key Result

Lemma 4.2

Suppose the SDE defined by $\epsilon_\theta$ is consistent, and let $x_t \sim p_{t0}(x_t \mid x_0)$ with $x_0 \sim q_0(x_0)$. Then $x_t \sim q_t(x_t)$.
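Lemma 4.2 can be checked numerically in a toy case where everything is Gaussian. The sketch below is an illustration under stated assumptions, not the paper's setup: it reads $q_t$ as the time-$t$ marginal of the SDE, uses a one-dimensional variance-preserving forward SDE with constant $\beta = 1$, and replaces the learned $\epsilon_\theta$ with the exact score of a Gaussian $q_0$, which makes the backward SDE consistent by construction. All names and constants are hypothetical. It simulates the backward SDE, records $x_t$ along the way, and separately perturbs the generated $x_0$ through $p_{t0}$; by the lemma, both sets of samples should match the moments of $q_t$ up to Monte Carlo and discretization error.

```python
import math
import random

# Illustrative 1-D setting (assumptions, not the paper's model):
#   forward SDE dx = -0.5 x dt + dW, so p_t0(x_t | x_0) = N(a_t x_0, 1 - a_t^2), a_t = exp(-t/2);
#   target q_0 = N(MU, S2), giving marginals q_t = N(a_t MU, a_t^2 S2 + 1 - a_t^2).
# The backward SDE uses the exact score of q_t, so it is consistent by construction.
MU, S2, T = 1.5, 0.25, 1.0


def a(t: float) -> float:
    return math.exp(-0.5 * t)


def marginal(t: float) -> tuple[float, float]:
    """Mean and variance of q_t."""
    return a(t) * MU, a(t) ** 2 * S2 + 1.0 - a(t) ** 2


def score(x: float, t: float) -> float:
    """Exact score d/dx log q_t(x) of the Gaussian marginal."""
    m, v = marginal(t)
    return -(x - m) / v


def backward_path(n_steps: int, rng: random.Random) -> list[float]:
    """Euler-Maruyama for the reverse SDE dx = [f - g^2 * score] dt + g dW_bar, run from T to 0.
    Entry j of the result is the state at time T - j * dt."""
    dt = T / n_steps
    m, v = marginal(T)
    x = rng.gauss(m, math.sqrt(v))  # start from q_T
    path = [x]
    for k in range(n_steps, 0, -1):
        t = k * dt
        # f(x, t) = -0.5 x and g = 1, so stepping t -> t - dt adds (0.5 x + score) dt plus noise.
        x = x + (0.5 * x + score(x, t)) * dt + math.sqrt(dt) * rng.gauss(0.0, 1.0)
        path.append(x)
    return path


def mean_var(xs: list[float]) -> tuple[float, float]:
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)


def check_lemma(j: int = 50, n_steps: int = 100, n: int = 4000, seed: int = 0):
    """Compare x_t read off the backward SDE with x_t obtained by perturbing its output x_0."""
    rng = random.Random(seed)
    t = T - j * T / n_steps
    from_sde, from_perturbation = [], []
    for _ in range(n):
        path = backward_path(n_steps, rng)
        from_sde.append(path[j])
        x0 = path[-1]  # generated sample x_0 ~ q_0
        from_perturbation.append(a(t) * x0 + math.sqrt(1.0 - a(t) ** 2) * rng.gauss(0.0, 1.0))
    return t, marginal(t), mean_var(from_sde), mean_var(from_perturbation)


if __name__ == "__main__":
    t, exact, sde_moments, pert_moments = check_lemma()
    print(f"t = {t:.2f}  exact q_t {exact}  backward SDE {sde_moments}  perturbed x_0 {pert_moments}")
```

If the score were replaced by an inconsistent model, the two sample sets would generally disagree, which is the gap the consistency constraint is meant to close.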

Figures (11)

  • Figure 1: Illustration of our motivation and the advantages of our method. Top: Trajectories sampled by an SDE-based policy. The terminal states (i.e., generated samples at $t=0$) are denoted $x_0^1, x_0^2, x_0^3, x_0^4$, and the reward function is denoted $R(\cdot)$. A data-scarce region is marked with a red cross. Middle: Marginal distributions of consistent forward and backward SDEs over time $t$. Bottom: Trajectories perturbed from $x_0^1, x_0^2, x_0^3, x_0^4$ via the forward SDE. In vanilla SDE-based policy gradient methods (top), SDE simulation is expensive in computation and time, so the sampled trajectories and rewards are usually sparse. The policy gradients in data-scarce regions are therefore ill-defined, which leads to instability. Consistency, which can be ensured via score matching, allows us to estimate the policy gradients correctly using plenty of data sampled efficiently from the forward SDE (i.e., the perturbation process).
  • Figure 2: Prediction error of policy gradients versus the number of trajectories, for different dimensionalities. We evaluate $\mathbb{E}_{x_t}\lVert\nabla_{x_t} Q_\phi(x_t, \pi_\theta(x_t,t), t) - \nabla_{x_t}Q_{\phi^*}(x_t, \pi_\theta(x_t,t), t) \rVert$, where $\phi^*$ is trained on a large number of trajectories and $\phi$ on a small number. The policy-gradient prediction error of our method is much lower than that of DDPG. See the appendix for details of this experiment.
  • Figure 3: Reward of $\pi_{\theta}(x_1, 1)$ across different regions. The policy receives high reward in brightly colored regions. The policy works well only in regions close to the training set.
  • Figure 4: Optimization curves showing how the average Vina score of generated ligand molecules changes over optimization iterations.
  • Figure 5: Examples of generated ligands. Carbon atoms in ligand molecules generated by TargetDiff [guan3d] and DiffAC are shown in green and cyan, respectively. We select cases where TargetDiff readily generates unrealistic ligand molecules that physically clash with protein surfaces, which usually leads to extremely poor Vina scores. DiffAC samples realistic, high-quality ligand molecules in these hard cases.
  • ...and 6 more figures
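Figure 1 notes that consistency can be ensured via score matching. A standard form for SDEs parameterized by a noise predictor such as $\epsilon_\theta$ is denoising score matching with a noise-prediction target, and the sketch below estimates that objective by Monte Carlo. It is an assumption-laden illustration: the variance-preserving perturbation kernel, the uniform sampling of $t$, the unweighted loss, the Gaussian data, and all names are hypothetical choices, not taken from the paper.

```python
import math
import random

# Illustrative assumptions (not the paper's training code):
#   perturbation x_t = a_t x_0 + sigma_t * eps, a_t = exp(-t/2), sigma_t = sqrt(1 - a_t^2);
#   1-D data q_0 = N(MU, S2); a "model" is any callable eps_model(x_t, t) -> float.
MU, S2 = 1.5, 0.25


def a(t: float) -> float:
    return math.exp(-0.5 * t)


def sigma(t: float) -> float:
    return math.sqrt(1.0 - a(t) ** 2)


def dsm_loss(eps_model, n: int, rng: random.Random, t_min: float = 1e-3, t_max: float = 1.0) -> float:
    """Monte Carlo estimate of E_{t, x_0, eps} (eps_model(a_t x_0 + sigma_t eps, t) - eps)^2."""
    total = 0.0
    for _ in range(n):
        t = rng.uniform(t_min, t_max)
        x0 = rng.gauss(MU, math.sqrt(S2))
        eps = rng.gauss(0.0, 1.0)
        xt = a(t) * x0 + sigma(t) * eps  # one-step draw from the perturbation process
        total += (eps_model(xt, t) - eps) ** 2
    return total / n


def optimal_eps(xt: float, t: float) -> float:
    """E[eps | x_t] for Gaussian data, i.e. -sigma_t times the score of q_t; it minimizes the loss."""
    var_t = a(t) ** 2 * S2 + sigma(t) ** 2
    return sigma(t) * (xt - a(t) * MU) / var_t


def zero_eps(xt: float, t: float) -> float:
    """A deliberately poor model that ignores its input."""
    return 0.0


if __name__ == "__main__":
    rng = random.Random(0)
    print("optimal model loss:", dsm_loss(optimal_eps, 20_000, rng))
    print("zero model loss   :", dsm_loss(zero_eps, 20_000, rng))
```

Every $x_t$ comes from a single draw of the perturbation kernel, so the objective covers the whole time range without simulating trajectories. The exact conditional mean `optimal_eps` scores well below the constant predictor, illustrating that minimizing this loss drives the model toward the score of $q_t$.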

Theorems & Definitions (9)

  • Definition 4.1: Consistent SDE
  • Lemma 4.2
  • proof
  • Theorem 4.3
  • Lemma 4.4
  • proof
  • Theorem 4.5
  • Theorem 3.1
  • proof