Rao-Blackwell Gradient Estimators for Equivariant Denoising Diffusion

Vinh Tong, Hoang Trung-Dung, Anji Liu, Guy Van den Broeck, Mathias Niepert

TL;DR

This work interprets data augmentation as a Monte Carlo estimator of the training gradient and applies Rao-Blackwellization, leading to more stable optimization, faster convergence, and reduced variance, all while requiring only a single forward and backward pass per sample.

Abstract

In domains such as molecular and protein generation, physical systems exhibit inherent symmetries that are critical to model. Two main strategies have emerged for learning invariant distributions: designing equivariant network architectures and using data augmentation to approximate equivariance. While equivariant architectures preserve symmetry by design, they often involve greater complexity and pose optimization challenges. Data augmentation, on the other hand, offers flexibility but may fall short in fully capturing symmetries. Our framework enhances both approaches by reducing training variance and providing a provably lower-variance gradient estimator. We achieve this by interpreting data augmentation as a Monte Carlo estimator of the training gradient and applying Rao-Blackwellization. This leads to more stable optimization, faster convergence, and reduced variance, all while requiring only a single forward and backward pass per sample. We also present a practical implementation of this estimator incorporating the loss and sampling procedure through a method we call Orbit Diffusion. Theoretically, we guarantee that our loss admits equivariant minimizers. Empirically, Orbit Diffusion achieves state-of-the-art results on GEOM-QM9 for molecular conformation generation, improves crystal structure prediction, and advances text-guided crystal generation on the Perov-5 and MP-20 benchmarks. Additionally, it enhances protein designability in protein structure generation. Code is available at: https://github.com/vinhsuhi/Orbit-Diffusion.git.

Paper Structure

This paper contains 58 sections, 9 theorems, 89 equations, 9 figures, 9 tables.

Key Result

Theorem 1

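Let $\widehat{\nabla}_\phi$ be the Monte Carlo gradient from eqn:sym_gradient and $\widehat{\nabla}_\phi^{(RB)}$ be the Rao-Blackwellized gradient from eqn:new_sym_grad. If $\mathbb{E}[x_0 \mid x_t]$ can be computed exactly, then $\mathrm{Var}\big(\widehat{\nabla}_\phi^{(RB)}\big) \le \mathrm{Var}\big(\widehat{\nabla}_\phi\big)$, with strict inequality unless $x_0 \mid x_t$ is a Dirac delta.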
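
The variance comparison in Theorem 1 follows the usual Rao-Blackwell pattern, which can be sketched with the law of total variance. The notation $g(x_0, x_t)$ for the per-sample gradient is introduced here for illustration only, and the sketch assumes $g$ is affine in $x_0$ (as it would be for a squared-error denoising loss); neither is taken from the paper's own proof.

```latex
% Assumption: g(x_0, x_t) is the per-sample gradient and is affine in x_0,
% e.g. for a squared-error loss \| D_\phi(x_t) - x_0 \|^2 (hypothetical notation).
\begin{align*}
\widehat{\nabla}_\phi &= g(x_0, x_t), \qquad
\widehat{\nabla}_\phi^{(RB)} = \mathbb{E}\big[g(x_0, x_t) \mid x_t\big]
  = g\big(\mathbb{E}[x_0 \mid x_t],\, x_t\big), \\
\mathrm{Var}\big(\widehat{\nabla}_\phi\big)
&= \mathbb{E}\big[\mathrm{Var}\big(g(x_0, x_t) \mid x_t\big)\big]
  + \mathrm{Var}\big(\mathbb{E}[g(x_0, x_t) \mid x_t]\big) \\
&= \mathbb{E}\big[\mathrm{Var}\big(g(x_0, x_t) \mid x_t\big)\big]
  + \mathrm{Var}\big(\widehat{\nabla}_\phi^{(RB)}\big)
\;\ge\; \mathrm{Var}\big(\widehat{\nabla}_\phi^{(RB)}\big).
\end{align*}
```

Both estimators share the same expectation by the tower property, so the reduction costs no bias. The gap is $\mathbb{E}[\mathrm{Var}(g(x_0, x_t) \mid x_t)]$, which vanishes when $x_0$ is fully determined by $x_t$, consistent with the Dirac delta condition in the theorem. For vector-valued gradients, $\mathrm{Var}$ can be read as the trace of the covariance matrix.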

Figures (9)

  • Figure 1: Gradient estimation strategies for training (approximately) equivariant diffusion models: (a) Sampling from the symmetrized joint distribution to obtain $x_0$ and $x_t$. (b) The standard data augmentation approach, which directly uses these samples for training. (c) The proposed method, leveraging self-normalizing importance sampling (SNIS) to estimate the inner conditional expectation. Both (b) and (c) require a single neural function evaluation per gradient step, but (c) has lower variance than (b). The pseudo-code for the Rao-Blackwell estimator with SNIS is shown on the right.
  • Figure 2: Learning curves.
  • Figure 3: Molecular conformer generation performance on GEOM-QM9. * Reported in the original paper. † Obtained using the published checkpoint. ‡ We train the public implementation from scratch.
  • Figure 4: Qualitative comparison of crystal structure predictions by 9 models, including DiffCSP, TGDMat (S), and TGDMat (L) with baselines, OrbDiff_U, and OrbDiff_WN, against ground truth on randomly selected samples from the Perov-5 and MP-20 datasets.
  • Figure 5: Protein Structure Generation. "+ [finetune]" denotes $\mathcal{M}_{FS}^\text{no-tri}$ finetuned with the original loss; "+ [OrbDiff]" uses OrbDiff for finetuning. The full table is in the appendix.
  • ...and 4 more figures
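
The Figure 1 caption describes estimating the inner conditional expectation with self-normalizing importance sampling (SNIS). The following is a minimal sketch of one plausible reading of that step, not the authors' implementation. It assumes a Gaussian forward process $x_t = \alpha_t x_0 + \sigma_t \epsilon$, restricts the group to 2D rotations for brevity, and draws proposals uniformly from the orbit of a single data point. The names `rotation_matrices`, `snis_orbit_mean`, `alpha_t`, and `sigma_t` are hypothetical.

```python
import numpy as np

def rotation_matrices(angles):
    """Return a stack of 2x2 rotation matrices with shape (K, 2, 2)."""
    c, s = np.cos(angles), np.sin(angles)
    return np.stack([np.stack([c, -s], axis=-1),
                     np.stack([s, c], axis=-1)], axis=-2)

def snis_orbit_mean(x0, xt, alpha_t, sigma_t, angles):
    """SNIS estimate of E[x0 | xt] over the rotation orbit of x0.

    x0, xt : (N, 2) point clouds (clean sample and its noised version).
    angles : (K,) proposal rotations, assumed drawn uniformly from the group.
    Returns the weighted orbit mean (N, 2) and the normalized weights (K,).
    """
    R = rotation_matrices(angles)                # (K, 2, 2)
    orbit = np.einsum('kij,nj->kni', R, x0)      # (K, N, 2): g_k applied to x0
    # Unnormalized log-weights: log N(xt; alpha_t * g_k x0, sigma_t^2 I) + const.
    sq_dist = ((xt[None] - alpha_t * orbit) ** 2).sum(axis=(1, 2))
    logw = -sq_dist / (2.0 * sigma_t ** 2)
    logw -= logw.max()                           # stabilize the exponentials
    w = np.exp(logw)
    w /= w.sum()                                 # self-normalization
    target = np.einsum('k,kni->ni', w, orbit)    # weighted average over the orbit
    return target, w

# Usage: noise a rotated copy of x0, then estimate the denoising target.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(5, 2))
alpha_t, sigma_t = 0.9, 0.1
R_true = rotation_matrices(np.array([1.0]))[0]
xt = alpha_t * x0 @ R_true.T + sigma_t * rng.normal(size=x0.shape)
angles = rng.uniform(0.0, 2.0 * np.pi, size=64)
target, weights = snis_orbit_mean(x0, xt, alpha_t, sigma_t, angles)
```

In this sketch the weighted orbit mean would replace the single augmented $x_0$ as the regression target, so the network is still evaluated once per training sample; only the target computation changes.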

Theorems & Definitions (16)

  • Theorem 1
  • Theorem 2
  • Theorem 1
  • proof
  • Proposition 1
  • Lemma 1
  • Lemma 2
  • proof
  • proof
  • Lemma 1
  • ...and 6 more