Variational Schrödinger Momentum Diffusion
Kevin Rojas, Yixin Tan, Molei Tao, Yuriy Nevmyvaka, Wei Deng
TL;DR
This paper addresses the scalability gap in transport-enabled diffusion models by introducing Variational Schrödinger Momentum Diffusion (VSMD). VSMD is a simulation-free framework that uses linearized variational forward scores, together with a damping transform that stabilizes training. The authors formulate VSMD as an adaptive multivariate diffusion with velocity-augmented states. From this formulation they derive forward–backward SDEs, develop a stochastic-approximation-based algorithm that adaptively optimizes transport, and prove convergence results. Empirically, VSMD generates anisotropic shapes efficiently, converges quickly, and performs competitively on time-series forecasting and CIFAR-10 image generation, without warm-up trajectories or complex denoising. The approach offers scalable optimal-transport (OT)-enabled generation with reduced reliance on forward simulation. This makes it practical for real-world data and multimodal generation tasks.
Abstract
The momentum Schrödinger Bridge (mSB) has emerged as a leading method for accelerating generative diffusion processes and reducing transport costs. However, because mSB is not simulation-free, it incurs high training costs and scales poorly. To balance transport properties and scalability, we introduce variational Schrödinger momentum diffusion (VSMD). VSMD employs linearized forward score functions (variational scores) to eliminate the dependence on simulated forward trajectories. Our approach leverages a multivariate diffusion process with adaptively transport-optimized variational scores. Additionally, we apply a critical-damping transform that stabilizes training by removing the need to estimate scores for both velocity and samples. Theoretically, we prove the convergence of samples generated with optimal variational scores and momentum diffusion. Empirical results demonstrate that VSMD efficiently generates anisotropic shapes while maintaining transport efficacy, outperforms overdamped alternatives, and avoids complex denoising processes. Our approach also scales effectively to real-world data, achieving competitive results in time series and image generation.
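The reason a linearized forward score removes the need for simulated trajectories can be illustrated generically. When the forward drift is linear in the augmented state U = [x, v], the conditional law of U_t given U_0 is Gaussian. Its mean and covariance, and hence the forward score -Σ_t^{-1}(u - μ_t), are then available in closed form. The sketch below is not the authors' VSMD implementation, which additionally optimizes the linear scores for transport. It assumes a critically damped Langevin (CLD-style) position/velocity SDE with hypothetical parameters M (mass), Γ (damping) and β (time scale). It computes the covariance with Van Loan's block-matrix-exponential method. The function names `momentum_sde`, `gaussian_transition` and `sample_and_score` are illustrative only.

```python
import numpy as np
from scipy.linalg import expm


def momentum_sde(mass=1.0, damping=None, beta=1.0):
    """Drift A and diffusion G of a linear SDE dU = A U dt + G dW, with U = [x, v].

    Hypothetical CLD-style parametrization for 1-D data:
        dx = beta * v / M dt
        dv = -beta * (x + Gamma * v / M) dt + sqrt(2 * Gamma * beta) dW
    Gamma = 2 * sqrt(M) is the critically damped choice (repeated drift eigenvalue).
    """
    if damping is None:
        damping = 2.0 * np.sqrt(mass)  # critical damping
    A = beta * np.array([[0.0, 1.0 / mass], [-1.0, -damping / mass]])
    G = np.array([[0.0], [np.sqrt(2.0 * damping * beta)]])
    return A, G


def gaussian_transition(A, G, t):
    """Closed-form mean map e^{At} and covariance of U_t | U_0 (Van Loan's method)."""
    d = A.shape[0]
    block = np.zeros((2 * d, 2 * d))
    block[:d, :d] = -A
    block[:d, d:] = G @ G.T
    block[d:, d:] = A.T
    F = expm(block * t)
    Phi = F[d:, d:].T              # e^{A t}
    cov = Phi @ F[:d, d:]          # int_0^t e^{A s} G G^T e^{A^T s} ds
    return Phi, 0.5 * (cov + cov.T)  # symmetrize against round-off


def sample_and_score(u0, A, G, t, rng):
    """Draw U_t without simulating a trajectory and return its Gaussian conditional score."""
    Phi, cov = gaussian_transition(A, G, t)
    mean = Phi @ u0
    u_t = rng.multivariate_normal(mean, cov)
    score = -np.linalg.solve(cov, u_t - mean)  # grad_u log N(u; mean, cov)
    return u_t, score


A, G = momentum_sde(mass=1.0)
rng = np.random.default_rng(0)
u_t, score = sample_and_score(np.array([2.0, 0.0]), A, G, t=1.0, rng=rng)
print(u_t.shape, score.shape)  # (2,) (2,)
```

With Γ = 2√M, the drift matrix has a single repeated eigenvalue -β/√M. This is the boundary between oscillatory (underdamped) and overdamped behaviour. For this parametrization the stationary law is x ~ N(0, 1), v ~ N(0, M). The conditional covariance therefore also equals Σ_∞ - e^{At} Σ_∞ e^{A^T t}, which is a convenient correctness check on the Van Loan computation.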
