Generative Modeling by Minimizing the Wasserstein-2 Loss
Yu-Jui Huang, Zachariah Malik
TL;DR
The paper develops a gradient-flow framework for unsupervised generative modeling by minimizing the Wasserstein-2 loss $J(\mu)=\tfrac{1}{2}W_2^2(\mu,\mu_d)$. It derives a distribution-dependent ODE driven by the Kantorovich potential, shows that the time-marginal laws of this ODE form a gradient flow for $J$ that converges exponentially to the data distribution, and connects this flow to a nonlinear Fokker–Planck equation via Trevisan's superposition principle. An explicit time-changed geodesic $\bm{\mu}^*_t$ is constructed. To simulate the gradient flow, the paper proposes a forward-Euler discretization (W2-FE) and trains the generator with persistent minibatch updates, which yields faster convergence than WGAN baselines in both low and high dimensions. The method performs well on synthetic 2D tasks and on an MNIST-to-USPS domain-adaptation task, and the paper clarifies when persistent training helps or hurts related GAN frameworks. Overall, the work offers a principled and scalable way to realize Wasserstein-2 gradient flows for generative modeling, pairing practical algorithms with theoretical convergence guarantees.
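To make the ODE and its forward-Euler discretization concrete, here is a minimal NumPy sketch in one dimension, where the optimal transport map $T$ between equal-size empirical samples is simply the monotone (sorted) matching, so no potential has to be learned. It assumes the convention $\nabla\varphi(x) = x - T(x)$ for the Kantorovich potential, so the ODE reads $dY_t = (T(Y_t) - Y_t)\,dt$ and one Euler step with step size $h$ is $Y_{n+1} = Y_n + h\,(T(Y_n) - Y_n)$. It also assumes the time-changed geodesic takes the form $e^{-t}Y_0 + (1-e^{-t})\,T(Y_0)$, i.e. the straight line toward $T(Y_0)$ reparametrized by $s = 1 - e^{-t}$. The helper names (`w2_1d`, `optimal_targets_1d`, `w2_forward_euler`) and the Gaussian toy data are hypothetical illustrations, not the authors' W2-FE implementation.

```python
import numpy as np

def w2_1d(x, y):
    """Empirical W2 between equal-size 1D samples; sorted matching is optimal in 1D."""
    return np.sqrt(np.mean((np.sort(x) - np.sort(y)) ** 2))

def optimal_targets_1d(x, y):
    """Return T(x): the partner of each x[i] under the monotone matching to y."""
    tx = np.empty_like(x)
    tx[np.argsort(x)] = np.sort(y)   # k-th smallest x is sent to k-th smallest y
    return tx

def w2_forward_euler(x0, data, h, n_steps):
    """Forward Euler for dY = (T(Y) - Y) dt with step size h, on 1D particles."""
    x = x0.copy()
    history = [w2_1d(x, data)]
    for _ in range(n_steps):
        tx = optimal_targets_1d(x, data)   # recompute T at the current law
        x = x + h * (tx - x)               # Y_{n+1} = Y_n - h * grad(phi)(Y_n)
        history.append(w2_1d(x, data))
    return x, np.array(history)

rng = np.random.default_rng(0)
x0 = rng.normal(0.0, 1.0, size=500)      # initial generated samples
data = rng.normal(3.0, 0.5, size=500)    # toy "data distribution" (assumption)
h, n_steps = 0.1, 30
x, hist = w2_forward_euler(x0, data, h, n_steps)

# Each Euler step shrinks the empirical W2 by the factor (1 - h) in this setting.
print(np.allclose(hist[1:] / hist[:-1], 1 - h))  # True

# Assumed time-changed geodesic at t = n_steps * h: e^{-t} Y_0 + (1 - e^{-t}) T(Y_0).
t = n_steps * h
geo = np.exp(-t) * x0 + (1 - np.exp(-t)) * optimal_targets_1d(x0, data)
print(w2_1d(x, data), w2_1d(geo, data))  # (1-h)^n * W2_0 vs e^{-t} * W2_0
```

In this exact 1D setting each Euler step preserves the order of the particles, so the matching never changes and the empirical $W_2$ contracts by exactly $1-h$ per step, giving $(1-h)^n$ after $n$ steps, while the continuous flow gives $e^{-t}$. The two agree as $h \to 0$ with $t = nh$, which mirrors the exponential convergence stated above.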
Abstract
This paper approaches the unsupervised learning problem by minimizing the second-order Wasserstein loss (the $W_2$ loss) through a distribution-dependent ordinary differential equation (ODE), whose dynamics involve the Kantorovich potential between the true data distribution and a current estimate of it. A main result shows that the time-marginal laws of the ODE form a gradient flow for the $W_2$ loss, which converges exponentially to the true data distribution. An Euler scheme for the ODE is proposed and shown to recover the gradient flow for the $W_2$ loss in the limit. An algorithm is designed by following the scheme and applying persistent training, which naturally fits our gradient-flow approach. In both low- and high-dimensional experiments, our algorithm outperforms Wasserstein generative adversarial networks when the level of persistent training is increased appropriately.
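The combination of an Euler scheme with persistent training can be illustrated with a toy loop: at each outer iteration, generated samples are pushed one Euler step toward the data, and the generator is then fitted to these fixed targets for $K$ gradient steps (the persistence level). This is a hedged sketch under strong simplifying assumptions, not the paper's algorithm: the generator is a hypothetical affine map $G(z) = az + b$, the Kantorovich potential is replaced by the exact 1D sorted matching rather than a learned network, and `train`, `persist_k`, `step`, `lr`, and `w2_to_data` are illustrative names and values.

```python
import numpy as np

def matched_targets(x, y):
    """Partner of each x[i] under the monotone (optimal) 1D matching to y."""
    tx = np.empty_like(x)
    tx[np.argsort(x)] = np.sort(y)
    return tx

def train(data, persist_k, outer_iters=50, step=0.2, lr=0.1, batch=256, seed=0):
    """Toy Euler-plus-persistence training of a hypothetical affine generator."""
    rng = np.random.default_rng(seed)
    a, b = 1.0, 0.0                              # initial generator G(z) = a*z + b
    for _ in range(outer_iters):
        z = rng.normal(size=batch)
        x = a * z + b                            # current generated samples
        d = rng.choice(data, size=batch, replace=False)
        # Euler target: x - step * grad(phi)(x) = x + step * (T(x) - x)
        tgt = x + step * (matched_targets(x, d) - x)
        for _ in range(persist_k):               # persistent training: reuse (z, tgt) K times
            r = a * z + b - tgt                  # residual of the loss 0.5 * mean(r**2)
            a -= lr * np.mean(r * z)             # gradient step on a
            b -= lr * np.mean(r)                 # gradient step on b
    return a, b

def w2_to_data(a, b, data, n=2000, seed=1):
    """Empirical 1D W2 between G's samples and a data subsample."""
    x = a * np.random.default_rng(seed).normal(size=n) + b
    d = np.random.default_rng(seed + 1).choice(data, size=n, replace=False)
    return np.sqrt(np.mean((np.sort(x) - np.sort(d)) ** 2))

data = np.random.default_rng(42).normal(3.0, 0.5, size=5000)  # toy data (assumption)
for k in (1, 5, 20):
    a, b = train(data, persist_k=k)
    print(k, round(w2_to_data(a, b, data), 3))
```

In this toy, raising `persist_k` lets the generator absorb more of each Euler step before new samples are drawn, so $W_2$ falls in fewer outer iterations. Because the affine generator cannot overfit a single minibatch, the sketch does not reproduce regimes in which heavy persistence hurts; that is one reason the level of persistent training has to be chosen with care.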
