MDNS: Masked Diffusion Neural Sampler via Stochastic Optimal Control

Yuchen Zhu, Wei Guo, Jaemoo Choi, Guan-Horng Liu, Yongxin Chen, Molei Tao

TL;DR

The efficiency and scalability of MDNS are validated through extensive experiments on various distributions with distinct statistical properties, where MDNS learns to accurately sample from the target distributions despite the extremely high problem dimensions and outperforms other learning-based baselines by a large margin.

Abstract

We study the problem of learning a neural sampler to generate samples from discrete state spaces where the target probability mass function $\pi\propto\mathrm{e}^{-U}$ is known up to a normalizing constant, an important task in fields such as statistical physics, machine learning, and combinatorial optimization. To better address this challenging task when the state space has a large cardinality and the distribution is multi-modal, we propose $\textbf{M}$asked $\textbf{D}$iffusion $\textbf{N}$eural $\textbf{S}$ampler ($\textbf{MDNS}$), a novel framework for training discrete neural samplers by aligning two path measures through a family of learning objectives, theoretically grounded in the stochastic optimal control of continuous-time Markov chains. We validate the efficiency and scalability of MDNS through extensive experiments on various distributions with distinct statistical properties, where MDNS learns to accurately sample from the target distributions despite the extremely high problem dimensions and outperforms other learning-based baselines by a large margin. A comprehensive study of ablations and extensions is also provided to demonstrate the efficacy and potential of the proposed framework. Our code is available at https://github.com/yuchen-zhu-zyc/MDNS.

Paper Structure

This paper contains 63 sections, 12 theorems, 80 equations, 18 figures, 6 tables, 4 algorithms.

Key Result

Lemma 1

Given two CTMCs with generators $Q^1,Q^2$ and initial distributions $\mu_1,\mu_2$ on $\mathcal{X}$, let $\mathbb{P}^1,\mathbb{P}^2$ be the associated path measures. Then for any trajectory $\xi=(\xi_t)_{t\in[0,T]}$,
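
The displayed formula of the lemma is missing from this extract. As a hedged illustration only, the sketch below evaluates the standard log Radon-Nikodym derivative between two CTMC path measures, $\log\frac{\mathrm{d}\mathbb{P}^2}{\mathrm{d}\mathbb{P}^1}(\xi)=\log\frac{\mu_2(\xi_0)}{\mu_1(\xi_0)}+\sum_{t:\,\xi_{t^-}\neq\xi_t}\log\frac{Q^2(\xi_{t^-},\xi_t)}{Q^1(\xi_{t^-},\xi_t)}+\int_0^T\big(Q^2(\xi_t,\xi_t)-Q^1(\xi_t,\xi_t)\big)\,\mathrm{d}t$, which may differ from the paper's statement in direction or notation. It assumes time-homogeneous generators stored as dense matrices whose rows sum to zero, a finite state space indexed by integers, and a trajectory given by its visited states and jump times. The function name `log_rn_derivative` is hypothetical and is not taken from the MDNS code base.

```python
import numpy as np

def log_rn_derivative(states, jump_times, T, Q1, Q2, mu1, mu2):
    """log dP^2/dP^1 along a piecewise-constant CTMC path on [0, T].

    states     : visited states x_0, x_1, ..., x_n (integer indices)
    jump_times : times 0 < t_1 < ... < t_n < T at which the path jumps
    Q1, Q2     : generator matrices (assumed time-homogeneous, rows sum to 0)
    mu1, mu2   : initial distributions as 1-D arrays
    """
    states = list(states)
    # Holding intervals [0, t_1), [t_1, t_2), ..., [t_n, T].
    grid = np.concatenate(([0.0], np.asarray(jump_times, dtype=float), [T]))
    holding = np.diff(grid)
    assert len(states) == len(holding), "need one state per holding interval"

    # Initial-distribution term.
    log_ratio = np.log(mu2[states[0]]) - np.log(mu1[states[0]])

    # Integral term: Q(x, x) = -(total exit rate from x), constant between jumps.
    for x, tau in zip(states, holding):
        log_ratio += (Q2[x, x] - Q1[x, x]) * tau

    # Jump term: ratio of off-diagonal rates at each transition.
    for x, y in zip(states[:-1], states[1:]):
        log_ratio += np.log(Q2[x, y]) - np.log(Q1[x, y])

    return float(log_ratio)
```

The value is finite only if both generators assign a positive rate to every jump the path makes and both initial distributions charge the starting state.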

Figures (18)

  • Figure 1: Average of 2-point correlation $C^{\mathrm{row}}(k, k + r)$ of samples from $16\times16$ Ising model.
  • Figure 2: Average of 2-point correlation $C^{\mathrm{row}}(k, k+ r)$ of samples from $16\times16$ Potts model.
  • Figure 3: Visualization of learning performance across the number of replicates $R$ for learning $4\times4$ Ising model with $J=1$ and $h=0.1$ using the WDCE loss. The metrics reported are $\operatorname{TV}(\widehat{p}_{\mathrm{samp}},\pi)$, $\operatorname{KL}(\widehat{p}_{\mathrm{samp}}\|\pi)$, and $\chi^2(\widehat{p}_{\mathrm{samp}}\|\pi)$.
  • Figure 4: Visualization of non-cherry-picked samples from the learned $16\times16$ Ising model with $J=1$, $h=0$, and $\beta_\mathrm{low}=0.6$. (a) MDNS. (b) LEAPS. (c) Ground Truth (simulated with SW algorithm).
  • Figure 5: Visualization of non-cherry-picked samples from the learned $16\times16$ Ising model with $J=1$, $h=0$, and $\beta_\mathrm{critical}=0.4407$. (a) MDNS. (b) LEAPS. (c) Ground Truth (simulated with SW algorithm).
  • ...and 13 more figures
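
The metrics named in the Figure 3 caption can be computed exactly on a $4\times4$ lattice, since its $2^{16}$ configurations can be enumerated. The following is a minimal sketch under stated assumptions, not the authors' evaluation code: the energy convention $U(x)=-J\sum_{\langle i,j\rangle}x_ix_j-h\sum_i x_i$ with $\pi\propto\mathrm{e}^{-\beta U}$, periodic boundaries, the value of `beta`, and all function names are assumptions made for illustration.

```python
import numpy as np

# J and h follow the Figure 3 caption; beta and periodic boundaries are assumptions.
L, J, h, beta = 4, 1.0, 0.1, 0.3

def all_configs(L):
    """Enumerate all 2^(L*L) spin configurations with entries in {-1, +1}."""
    n = L * L
    idx = np.arange(2 ** n)
    bits = (idx[:, None] >> np.arange(n)) & 1
    return (2 * bits - 1).reshape(-1, L, L)

def energy(x, J, h):
    """Assumed Ising energy; for L > 2 each periodic bond is counted once."""
    bonds = x * np.roll(x, -1, axis=-1) + x * np.roll(x, -1, axis=-2)
    return -J * bonds.sum(axis=(-2, -1)) - h * x.sum(axis=(-2, -1))

def exact_pmf(L, J, h, beta):
    """Normalize exp(-beta * U) over the enumerated state space."""
    logw = -beta * energy(all_configs(L), J, h)
    logw -= logw.max()  # stabilize the exponentials
    w = np.exp(logw)
    return w / w.sum()

def empirical_pmf(sample_indices, n_states):
    """Histogram of sampled configuration indices."""
    counts = np.bincount(sample_indices, minlength=n_states)
    return counts / counts.sum()

def divergences(p_hat, pi):
    """Return TV(p_hat, pi), KL(p_hat || pi), chi^2(p_hat || pi)."""
    tv = 0.5 * np.abs(p_hat - pi).sum()
    mask = p_hat > 0  # terms with p_hat = 0 contribute nothing to the KL sum
    kl = np.sum(p_hat[mask] * np.log(p_hat[mask] / pi[mask]))
    chi2 = np.sum((p_hat - pi) ** 2 / pi)
    return tv, kl, chi2

pi = exact_pmf(L, J, h, beta)
# Stand-in for a learned sampler: draw exact samples from pi.
rng = np.random.default_rng(0)
samples = rng.choice(pi.size, size=100_000, p=pi)
tv, kl, chi2 = divergences(empirical_pmf(samples, pi.size), pi)
```

With a finite number of samples the empirical distribution never matches $\pi$ exactly, so all three metrics stay strictly positive even for a perfect sampler; comparing against this baseline helps separate sampling noise from model error.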

Theorems & Definitions (26)

  • Lemma 1
  • Lemma 2
  • Lemma 3
  • Proposition 1: Guarantee for sampling
  • Proposition 2: Guarantee for normalizing constant estimation
  • Lemma 4
  • proof
  • Lemma 5
  • proof
  • proof
  • ...and 16 more