
Discrete Adjoint Schrödinger Bridge Sampler

Wei Guo, Yuchen Zhu, Xiaochen Du, Juno Nam, Yongxin Chen, Rafael Gómez-Bombarelli, Guan-Horng Liu, Molei Tao, Jaemoo Choi

TL;DR

This work introduces Discrete Adjoint Schrödinger Bridge Sampler (DASBS), a principled framework that extends adjoint Schrödinger bridge ideas from continuous spaces to discrete continuous-time Markov chains (CTMCs) by leveraging a cyclic group structure on the state space. It formalizes discrete Schrödinger bridge (SB) and stochastic optimal control (SOC) theory for CTMCs, derives an explicit characterization of the optimal transition rate, and develops a DASBS algorithm that alternates between adjoint learning and corrector learning via matching losses. An analysis of memoryless versus non-memoryless reference dynamics clarifies when objectives inspired by the additive-noise setting are most effective, and convergence guarantees are provided for the fixed-point learning dynamics. Empirically, DASBS achieves competitive sample quality on lattice Ising and Potts models while offering notable gains in training efficiency and scalability, supported by ablation studies. The framework unifies existing memoryless SOC solvers under a discrete SB/SOC lens and points to future extensions to more complex discrete distributions and non-uniform reference dynamics.

Abstract

Learning discrete neural samplers is challenging due to the lack of gradients and combinatorial complexity. While stochastic optimal control (SOC) and Schrödinger bridge (SB) provide principled solutions, efficient SOC solvers like adjoint matching (AM), which excel in continuous domains, remain unexplored for discrete spaces. We bridge this gap by revealing that the core mechanism of AM is state-space agnostic, and introduce discrete ASBS, a unified framework that extends AM and adjoint Schrödinger bridge sampler (ASBS) to discrete spaces. Theoretically, we analyze the optimality conditions of the discrete SB problem and its connection to SOC, identifying a necessary cyclic group structure on the state space to enable this extension. Empirically, discrete ASBS achieves competitive sample quality with significant advantages in training efficiency and scalability.
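
The role of the cyclic group can be made concrete with a small simulation. The sketch below assumes a uniform reference CTMC on each coordinate of $(\mathbb{Z}_N)^d$: at rate $\gamma_t$ a coordinate jumps to $(x + k) \bmod N$ with $k$ drawn uniformly from $\{1,\dots,N-1\}$, so noise acts by group addition, loosely analogous to additive noise $x + \sigma\epsilon$ in continuous space. This reference process, the helper name `simulate_cyclic_reference`, and the time discretization are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def simulate_cyclic_reference(x0, n_states, gamma_fn, n_samples, n_steps=200, seed=0):
    """Hypothetical helper: simulate an assumed uniform-jump reference CTMC on (Z_N)^d.

    Each coordinate jumps at rate gamma_fn(t); a jump adds a shift k drawn uniformly
    from {1, ..., N-1} modulo N (the cyclic-group action). Time is split into n_steps
    intervals and at most one jump per coordinate is applied per interval, which is a
    first-order approximation of the continuous-time process.
    """
    rng = np.random.default_rng(seed)
    x = np.tile(np.asarray(x0, dtype=np.int64), (n_samples, 1))  # shape (n_samples, d)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t_mid = (k + 0.5) * dt  # midpoint keeps gamma_t = gamma / t finite in the first interval
        p_jump = 1.0 - np.exp(-gamma_fn(t_mid) * dt)  # P(at least one jump in this interval)
        jump = rng.random(x.shape) < p_jump
        shift = rng.integers(1, n_states, size=x.shape)  # k in {1, ..., N-1}
        x = np.where(jump, (x + shift) % n_states, x)  # group addition in Z_N
    return x

# Usage: compare the fraction of coordinates that end where they started.
N, gamma = 3, 1.0
x_init = np.array([0, 1, 2, 0])
x1 = simulate_cyclic_reference(x_init, N, lambda t: gamma, n_samples=20_000)
empirical_stay = np.mean(x1 == x_init)
# Closed form for the continuous-time chain with off-diagonal rates gamma / (N - 1):
exact_stay = 1 / N + (1 - 1 / N) * np.exp(-gamma * N / (N - 1))

x1_ml = simulate_cyclic_reference(x_init, N, lambda t: 1.0 / t, n_samples=20_000, seed=1)
memoryless_stay = np.mean(x1_ml == x_init)  # close to 1/N: X_1 has forgotten X_0
print(empirical_stay, exact_stay, memoryless_stay)
```

With the constant schedule some memory of $X_0$ survives, and the closed form gives exactly how much; with the schedule $\gamma_t = 1/t$ the rate blows up near $t=0$, so the end state is close to uniform regardless of the start.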

Paper Structure

This paper contains 70 sections, 16 theorems, 87 equations, 3 figures, 3 tables, and 1 algorithm.

Key Result

Theorem 3.1

The optimal transition rate $u^\star$ for the SB problem (eq:sb_prob) can be expressed as …, where the SB potentials $(\varphi_t, \widehat{\varphi}_t)$ satisfy …, and the optimal path measure $p^\star$ satisfies …
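
Characterizations of this kind typically take the Doob h-transform form, in which the optimal rate reweights the reference rate by a ratio of potentials. The sketch below checks that structure numerically in a deliberately simple setting: one coordinate on $\mathbb{Z}_N$, an assumed symmetric uniform reference generator $Q$, and a Dirac source at $x_0$. In that case the terminal potential is $\varphi_1 = \nu / p^{\mathrm{ref}}_1$, $\varphi_t = e^{(1-t)Q}\varphi_1$ solves the backward Kolmogorov equation $\partial_t \varphi_t = -Q\varphi_t$, and the rate $u_t(x,y) = Q(x,y)\,\varphi_t(y)/\varphi_t(x)$ for $y \neq x$ transports $\delta_{x_0}$ to the target $\nu$. The symbols $Q$, $\nu$, $x_0$ and all function names are our own assumptions; this is the classical construction, not the paper's Theorem 3.1 or its code.

```python
import numpy as np

def uniform_cyclic_generator(n_states, gamma):
    """Assumed reference rate matrix on Z_N: from x, jump to (x + k) mod N for each
    k = 1..N-1 at rate gamma / (N - 1). Rows sum to zero and Q is symmetric."""
    Q = np.zeros((n_states, n_states))
    for x in range(n_states):
        for k in range(1, n_states):
            Q[x, (x + k) % n_states] = gamma / (n_states - 1)
        Q[x, x] = -gamma
    return Q

def expm_symmetric(Q, t):
    """exp(t * Q) for a symmetric Q via its eigendecomposition."""
    w, V = np.linalg.eigh(Q)
    return (V * np.exp(t * w)) @ V.T

def potentials_from_dirac(Q, x0, nu):
    """phi_1 = nu / p_1^ref for a Dirac source at x0; phi_t = exp((1 - t) Q) phi_1."""
    p1_ref = expm_symmetric(Q, 1.0)[x0]  # reference marginal at t = 1 started from x0
    phi1 = nu / p1_ref
    return lambda t: expm_symmetric(Q, 1.0 - t) @ phi1

def controlled_generator(Q, phi_t):
    """Doob h-transform: u(x, y) = Q(x, y) * phi_t(y) / phi_t(x) for y != x."""
    U = Q * phi_t[None, :] / phi_t[:, None]
    np.fill_diagonal(U, 0.0)
    np.fill_diagonal(U, -U.sum(axis=1))  # diagonal makes each row sum to zero
    return U

def terminal_marginal(Q, x0, phi, n_steps=200):
    """Integrate the forward Kolmogorov equation dp/dt = p U_t from delta_{x0} with RK4."""
    p = np.zeros(Q.shape[0])
    p[x0] = 1.0
    dt = 1.0 / n_steps
    rhs = lambda t, q: q @ controlled_generator(Q, phi(t))
    for k in range(n_steps):
        t = k * dt
        k1 = rhs(t, p)
        k2 = rhs(t + dt / 2, p + dt / 2 * k1)
        k3 = rhs(t + dt / 2, p + dt / 2 * k2)
        k4 = rhs(t + dt, p + dt * k3)
        p = p + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return p

Q = uniform_cyclic_generator(5, gamma=3.0)
nu = np.array([0.4, 0.1, 0.2, 0.05, 0.25])
phi = potentials_from_dirac(Q, x0=0, nu=nu)
p1 = terminal_marginal(Q, 0, phi)
print(np.round(p1, 4))  # should agree with nu up to integration error
```

A non-symmetric reference generator would need a general matrix exponential, such as `scipy.linalg.expm`, in place of `expm_symmetric`.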

Figures (3)

  • Figure 1: Ablation study of the adjoint matching (AM) and denoising matching (DM) training losses for the memoryless noise schedule $\gamma_t=\frac{1}{t}$ on the Potts model with $L=8$, $N=3$, $\beta_\mathrm{high}=0.5$. Reweighting means using the trajectory importance weight $p^\star/p^{\operatorname{sg}(u)}$ (sg: stop-gradient) in the training losses. DM with reweighting corresponds to UDNS (zhu2025mdns). Left: energy Wasserstein-2 distance to ground-truth samples from the Swendsen–Wang (SW) algorithm. Right: effective sample size computed from the trajectory importance weights.
  • Figure 2: Ablation study of the hyperparameters $\alpha$ and $\gamma$ for the modified log-linear noise schedule $\gamma_t=\frac{\gamma}{t+\alpha}$ on the Ising model with $L=24$ and $\beta_\mathrm{high}=0.28$. The case $\alpha=0$ is memoryless. NFE is the number of function evaluations during generation, for both training and inference. Top: $\gamma=1$ fixed, $\alpha$ varied. Bottom: $\alpha=1$ fixed, $\gamma$ varied. Left: average number of jumps per dimension during generation. Right: 2-point correlation error and energy Wasserstein-2 distance to ground-truth samples drawn with the SW algorithm.
  • Figure 3: Ablation study of the hyperparameter $\gamma$ for the constant noise schedule $\gamma_t\equiv\gamma$ on the Ising model with $L=24$ and $\beta_\mathrm{high}=0.28$. NFE is the number of function evaluations during generation, for both training and inference. Left: average number of jumps per dimension during generation. Right: 2-point correlation error and energy Wasserstein-2 distance to ground-truth samples drawn with the SW algorithm.
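
The captions above rely on a few scalar quantities. As a hedged sketch (the helper names are ours, and the paper's exact estimators may differ), the block below computes: the integrated rate $\int \gamma_t\,dt$ of the modified log-linear and constant schedules, which for a Poisson jump clock is the expected number of reference jump events per dimension, together with the no-jump probability $\exp(-\int \gamma_t\,dt)$; a normalized effective sample size from log importance weights; and a Wasserstein-2 distance between two equal-size sets of sample energies. Note that the figures report jumps of the learned process during generation, whereas these formulas describe the reference schedule only. For $\gamma_t = \gamma/(t+\alpha)$ on $[0,1]$ the no-jump probability is $(\alpha/(1+\alpha))^\gamma$, which vanishes at $\alpha = 0$, matching the captions' description of that case as memoryless.

```python
import numpy as np

def integrated_rate_loglinear(gamma, alpha, t_start=0.0, t_end=1.0):
    """Integral of gamma_t = gamma / (t + alpha) over [t_start, t_end].

    For a Poisson jump clock with rate gamma_t this is the expected number of jump
    events per dimension. It diverges when alpha = 0 and t_start = 0 (memoryless case).
    """
    if alpha == 0.0 and t_start == 0.0:
        return np.inf
    return gamma * np.log((t_end + alpha) / (t_start + alpha))

def integrated_rate_constant(gamma, t_start=0.0, t_end=1.0):
    """Integral of the constant schedule gamma_t = gamma over [t_start, t_end]."""
    return gamma * (t_end - t_start)

def no_jump_probability(integrated_rate):
    """Probability that a dimension never jumps: exp(-integral of gamma_t)."""
    return float(np.exp(-integrated_rate))

def normalized_ess(log_weights):
    """ESS / n from log importance weights; shifting by the max avoids overflow."""
    log_weights = np.asarray(log_weights, dtype=float)
    w = np.exp(log_weights - np.max(log_weights))
    return float(w.sum() ** 2 / (w ** 2).sum() / len(w))

def energy_w2(energies_a, energies_b):
    """W2 between two 1D empirical distributions of equal size: pair sorted
    samples (the optimal 1D coupling) and take the root-mean-square gap."""
    a = np.sort(np.asarray(energies_a, dtype=float))
    b = np.sort(np.asarray(energies_b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this simple formula assumes equal sample sizes")
    return float(np.sqrt(np.mean((a - b) ** 2)))

# Usage: how much of the initial state the reference schedule can retain.
for alpha in [0.0, 0.1, 1.0]:
    lam = integrated_rate_loglinear(1.0, alpha)
    print(alpha, lam, no_jump_probability(lam))
```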

Theorems & Definitions (38)

  • Theorem 3.1
  • Theorem 3.2
  • Proof sketch
  • Proposition 5.1
  • Theorem 5.2
  • Proposition 2.1
  • Lemma 2.2
  • Proof
  • Lemma 2.3
  • Proof
  • ...and 28 more