Scalable Wasserstein Gradient Flow for Generative Modeling through Unbalanced Optimal Transport

Jaemoo Choi, Jaewoong Choi, Myungjoo Kang

TL;DR

This paper introduces a scalable WGF-based generative model, called Semi-dual JKO (S-JKO). The model is based on the semi-dual form of the JKO step, which is derived from the equivalence between the JKO step and Unbalanced Optimal Transport.

Abstract

Wasserstein Gradient Flow (WGF) describes the gradient dynamics of a probability density in Wasserstein space and provides a promising approach to optimization over probability distributions. Numerically approximating the continuous WGF requires a time discretization method, the best known of which is the JKO scheme. Previous WGF models therefore employ the JKO scheme and parametrize a transport map for each JKO step. However, this approach results in quadratic training complexity $O(K^2)$ in the number of JKO steps $K$, which severely limits the scalability of WGF models. In this paper, we introduce a scalable WGF-based generative model, called Semi-dual JKO (S-JKO). Our model is based on the semi-dual form of the JKO step, derived from the equivalence between the JKO step and Unbalanced Optimal Transport. Our approach reduces the training complexity to $O(K)$. We demonstrate that our model significantly outperforms existing WGF-based generative models, achieving FID scores of 2.62 on CIFAR-10 and 5.46 on CelebA-HQ-256, which are comparable to state-of-the-art image generative models.
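
For orientation, a single JKO step can be written in its standard textbook form as below. The step size $h$, the objective functional $\mathcal{F}$, and the 2-Wasserstein distance $W_2$ are notation assumed here for illustration; the abstract itself does not fix them.

$$
\mu_{k+1} = \operatorname*{arg\,min}_{\mu} \; \mathcal{F}(\mu) + \frac{1}{2h} W_2^2(\mu, \mu_k)
$$

Iterating this step $K$ times from an initial distribution $\mu_0$ gives a discrete-time approximation of the continuous flow of $\mathcal{F}$.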

Paper Structure (45 sections, 1 theorem, 35 equations, 14 figures, 5 tables, 3 algorithms)


Key Result

Lemma 1.1

Consider the following optimization problem: [equation missing from the source]. Then the semi-dual formulation of Eq. (eq:uot2) is given as [equation missing from the source], where $v^c(x) = \inf_y \left[ c\left(x, y\right) - v(y) \right]$. Moreover, strong duality holds.
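
The $c$-transform $v^c$ in the lemma can be illustrated on a finite set of points. The following is a minimal sketch, not the paper's implementation: the infimum over $y$ is approximated by a minimum over a finite candidate list, and the quadratic cost, the function names, and the sample values are all assumptions made for this example.

```python
from typing import Callable, List, Sequence


def quadratic_cost(x: float, y: float) -> float:
    """Assumed cost c(x, y) = 0.5 * (x - y)^2 on the real line (hypothetical choice)."""
    return 0.5 * (x - y) ** 2


def c_transform(
    v: Sequence[float],
    xs: Sequence[float],
    ys: Sequence[float],
    cost: Callable[[float, float], float] = quadratic_cost,
) -> List[float]:
    """Discrete c-transform: v^c(x_i) = min_j [ c(x_i, y_j) - v(y_j) ].

    v[j] holds the potential value v(y_j); the infimum over all y is
    replaced by a minimum over the finite candidate set ys.
    """
    if len(v) != len(ys):
        raise ValueError("v must have one value per point in ys")
    return [min(cost(x, y) - v_j for y, v_j in zip(ys, v)) for x in xs]


# Illustrative values only.
xs = [0.0, 1.0]
ys = [0.0, 2.0]
v = [0.0, 1.0]
vc = c_transform(v, xs, ys)
print(vc)
```

In a neural setting, $v$ would typically be a network and the minimization over $y$ would be handled by a learned map rather than by enumeration; that part is outside the scope of this sketch.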

Figures (14)

  • Figure 1: (a) Visualization of the Training Process of Existing JKO Models. In each training iteration, sampling from $\mu_k$ requires sequential inference through $k$ networks, i.e., $T_k \circ \dots \circ T_0(x)$ with $x\sim \mu$. This iterative network evaluation considerably slows down training. Formally, the training complexity becomes $O(K^2)$, where $K$ denotes the number of JKO steps. (b) Two Variants of the UOTMs. Left: Source-fixed UOTM. Right: Both-relaxed UOTM. For brevity, we refer to the Both-relaxed UOTM simply as UOTM. UOTMs allow flexibility in the marginal densities and therefore have inherent distribution errors (Blue: Source and Target distributions $\mu, \nu$. Orange: Marginal distributions of the Optimal Coupling $\pi_0, \pi_1$).
  • Figure 2: Conceptual Diagram of Our Model. During training of the $k$-th JKO step in our model, sampling from $\mu_k$ requires only one network inference through $T_{k-1}$, i.e., $\mu_k = {T_{k-1}}_\# \mu_0$. This reparametrization strategy significantly reduces the overall training time. Formally, the training time complexity is reduced from the $O(K^2)$ of other JKO models to $O(K)$. Moreover, by initializing the parameters of $T_k$ with those of $T_{k-1}$, we can further decrease the number of iterations required for training.
  • Figure 2: Image Generation on CIFAR-10. $\dagger$ indicates results obtained by ourselves.
  • Figure 2: Extensive Comparison with Diverse Generative Models on Image Generation on CIFAR-10. $\dagger$ indicates results obtained by ourselves.
  • Figure 3: Generation Results for UOTM, Source-fixed UOTM, and S-JKO on Synthetic Datasets. Each column shows the generated distribution at training iterations {5, 20, 40, 60, 100}K. $k$ denotes the index of the JKO step corresponding to that particular iteration.
  • ...and 9 more figures
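
The $O(K^2)$ versus $O(K)$ contrast described in the captions of Figure 1(a) and the conceptual diagram can be made concrete by counting network evaluations per sample. This is a toy count under simplifying assumptions (one sample per JKO step, batch size and iterations per step ignored); the function names are hypothetical and not taken from the paper.

```python
def evals_sequential(num_steps: int) -> int:
    """Evaluations when mu_k is sampled by composing earlier maps.

    Following the Figure 1(a) caption, drawing one sample at JKO step k
    costs k network evaluations, so the total grows quadratically in K.
    """
    return sum(k for k in range(1, num_steps + 1))


def evals_reparametrized(num_steps: int) -> int:
    """Evaluations when mu_k = (T_{k-1})_# mu_0.

    One network evaluation per JKO step, so the total grows linearly in K.
    """
    return sum(1 for _ in range(1, num_steps + 1))


if __name__ == "__main__":
    for K in (10, 100):
        print(K, evals_sequential(K), evals_reparametrized(K))
```

The sequential count equals $K(K+1)/2$, which is where the quadratic term comes from, while the reparametrized count is exactly $K$.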

Theorems & Definitions (2)

  • Lemma 1.1
  • proof