Neural Sinkhorn Gradient Flow

Huminhao Zhu, Fangyikang Wang, Chao Zhang, Hanbin Zhao, Hui Qian

TL;DR

The paper addresses efficient generative modeling in the Wasserstein space by leveraging the Sinkhorn divergence and gradient flows. It introduces Neural Sinkhorn Gradient Flow (NSGF), which learns the time-varying velocity field of the Sinkhorn gradient flow via velocity-field matching, using only samples from the source and target distributions. It also proves that the mean-field limit of the empirical velocity field converges to the true velocity field as the sample size grows. To handle high-dimensional data, NSGF++ combines a short Sinkhorn-flow phase (≤5 NFEs) with a phase-transition time predictor and a Neural Straight Flow that refines the samples, reducing computation while preserving quality. Experiments on 2D simulations and image datasets (MNIST, CIFAR-10) show competitive generation quality at lower inference cost than neural ODE-based diffusion and gradient-flow baselines, supporting the practical value of this kernel-free approach that needs only samples.
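The two-phase NSGF++ inference described above can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the authors' implementation: the three callables `nsgf_velocity`, `predict_t`, and `straight_velocity` are hypothetical stand-ins for the trained NSGF network, the phase-transition time predictor, and the Neural Straight Flow, and their signatures, the explicit Euler integrator, and the step counts are illustrative choices.

```python
from typing import Callable

import numpy as np

# A velocity field maps particles x (n, d) and per-particle times t (n,) to velocities (n, d).
VelocityField = Callable[[np.ndarray, np.ndarray], np.ndarray]


def nsgf_pp_sample(
    x0: np.ndarray,
    nsgf_velocity: VelocityField,                   # hypothetical: trained NSGF field
    predict_t: Callable[[np.ndarray], np.ndarray],  # hypothetical: phase-transition time predictor
    straight_velocity: VelocityField,               # hypothetical: Neural Straight Flow field
    n_nsgf_steps: int = 5,                          # phase-1 budget (the paper uses <= 5 NFEs)
    eta: float = 1.0,                               # phase-1 Euler step size (assumed)
    n_straight_steps: int = 10,                     # phase-2 Euler steps (assumed)
) -> np.ndarray:
    x = np.array(x0, dtype=float, copy=True)
    n = x.shape[0]

    # Phase 1: a few explicit Euler steps along the learned Sinkhorn flow.
    for k in range(n_nsgf_steps):
        x = x + eta * nsgf_velocity(x, np.full(n, k * eta))

    # Phase transition: estimate each sample's time on the straight path, clipped to [0, 1].
    t = np.clip(predict_t(x), 0.0, 1.0)

    # Phase 2: Euler-integrate the straight flow from each sample's own t up to t = 1.
    dt = (1.0 - t) / n_straight_steps
    for _ in range(n_straight_steps):
        x = x + dt[:, None] * straight_velocity(x, t)
        t = t + dt
    return x
```

In this sketch the predicted time sets a per-sample step size, so each sample only integrates the remainder of the straight path, from its predicted time to 1.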

Abstract

Wasserstein Gradient Flows (WGF) with respect to specific functionals have been widely used in the machine learning literature. Recently, neural networks have been adopted to approximate certain intractable parts of the underlying Wasserstein gradient flow, resulting in efficient inference procedures. In this paper, we introduce the Neural Sinkhorn Gradient Flow (NSGF) model, which parametrizes the time-varying velocity field of the Wasserstein gradient flow w.r.t. the Sinkhorn divergence to the target distribution, starting from a given source distribution. We utilize the velocity field matching training scheme in NSGF, which only requires samples from the source and target distributions to compute an empirical velocity field approximation. Our theoretical analyses show that as the sample size increases to infinity, the mean-field limit of the empirical approximation converges to the true underlying velocity field. To further enhance model efficiency on high-dimensional tasks, a two-phase NSGF++ model is devised, which first follows the Sinkhorn flow to approach the image manifold quickly ($\le 5$ NFEs) and then refines the samples along a simple straight flow. Numerical experiments on synthetic and real-world benchmark datasets support our theoretical results and demonstrate the effectiveness of the proposed methods.
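To make the "empirical velocity field approximation" concrete, here is a minimal NumPy/SciPy sketch, assuming the quadratic cost $C(x,y)=\tfrac12\|x-y\|^2$, uniform weights on the samples, and plain log-domain Sinkhorn iterations. It relies on the standard fact that the first variation of the Sinkhorn divergence $S_\varepsilon(\mu,\nu)$ in $\mu$ is $f_{\mu,\nu}-f_{\mu,\mu}$, so particles move along $-\nabla(f_{\mu,\nu}-f_{\mu,\mu})$. The helper names (`sinkhorn_g`, `grad_f`, `sinkhorn_velocity`) and the values of `eps`, `n_iter`, and the step size are illustrative assumptions, not the authors' code; in NSGF, velocities of this kind serve as regression targets for the neural velocity field.

```python
import numpy as np
from scipy.special import logsumexp


def sq_cost(x, y):
    """Pairwise cost C(x_i, y_j) = 0.5 * ||x_i - y_j||^2, shape (n, m)."""
    return 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_g(x, y, eps, n_iter):
    """Log-domain Sinkhorn between uniform empirical measures on x and y.
    Returns the dual potential g on the y samples; f is recovered from g in grad_f."""
    n, m = len(x), len(y)
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    C = sq_cost(x, y)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        f = -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    return g


def grad_f(x, y, g, eps):
    """Gradient at x of f(x) = -eps * log sum_j b_j exp((g_j - C(x, y_j)) / eps).
    For the quadratic cost this equals x minus a softmax-weighted mean of the y_j."""
    logits = (g[None, :] - sq_cost(x, y)) / eps  # uniform b_j cancels in the softmax
    w = np.exp(logits - logsumexp(logits, axis=1, keepdims=True))
    return x - w @ y


def sinkhorn_velocity(x, y, eps=0.5, n_iter=500):
    """Empirical Sinkhorn-flow velocity -grad(f_{mu,nu} - f_{mu,mu}) at particles x,
    where mu is the empirical measure on x and nu the one on y."""
    g_xy = sinkhorn_g(x, y, eps, n_iter)  # potential of OT_eps(mu, nu)
    g_xx = sinkhorn_g(x, x, eps, n_iter)  # potential of OT_eps(mu, mu)
    return grad_f(x, x, g_xx, eps) - grad_f(x, y, g_xy, eps)


# Usage: a few explicit Euler steps move the particles toward the target samples.
rng = np.random.default_rng(0)
x = 0.3 * rng.normal(size=(32, 2))
y = 0.3 * rng.normal(size=(32, 2)) + np.array([1.0, 0.0])
for _ in range(5):
    x = x + 0.5 * sinkhorn_velocity(x, y)
```

For the quadratic cost, the velocity reduces to a difference of two softmax-weighted barycenters, $\sum_j w^{\nu}_{ij} y_j - \sum_j w^{\mu}_{ij} x_j$, which is why it vanishes when the source and target samples coincide.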

Paper Structure (27 sections, 8 theorems, 57 equations, 11 figures, 3 tables, 3 algorithms)

Key Result

Lemma 1

(Optimality, cuturi2013sinkhorn) The $\mathcal{W}_{\varepsilon}$-potentials $(f_{\mu,\nu}, g_{\mu,\nu})$ exist and are unique $(\mu,\nu)$-a.e. up to an additive constant (i.e., for every $K \in \mathbb{R}$, $(f_{\mu,\nu} + K, g_{\mu,\nu} - K)$ is also optimal). Moreover,
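The additive-constant ambiguity in Lemma 1 can be checked numerically for discrete measures. The sketch below assumes the usual log-domain form of the entropic optimality (Sinkhorn) equation, $f_i = -\varepsilon \log \sum_j b_j \exp((g_j - C_{ij})/\varepsilon)$, and the entropic plan $\pi_{ij} = a_i b_j \exp((f_i + g_j - C_{ij})/\varepsilon)$; the helper names `soft_c_transform` and `entropic_plan` are illustrative. Shifting $g$ by $-K$ shifts the resulting $f$ by $+K$, and the plan is unchanged.

```python
import numpy as np
from scipy.special import logsumexp


def soft_c_transform(g, C, log_b, eps):
    """Sinkhorn optimality map: f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)."""
    return -eps * logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)


def entropic_plan(f, g, C, log_a, log_b, eps):
    """Entropic transport plan pi_ij = a_i b_j exp((f_i + g_j - C_ij) / eps)."""
    return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / eps)


# Small demo with uniform weights and the quadratic cost 0.5 * ||x - y||^2.
rng = np.random.default_rng(1)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))
C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
log_a, log_b = np.full(5, -np.log(5)), np.full(7, -np.log(7))
eps, K = 0.3, 0.7
g = rng.normal(size=7)
f = soft_c_transform(g, C, log_b, eps)

# Shifting g by -K shifts the Sinkhorn update of f by +K ...
print(np.allclose(soft_c_transform(g - K, C, log_b, eps), f + K))  # True
# ... and the pair (f + K, g - K) induces exactly the same transport plan.
print(np.allclose(entropic_plan(f + K, g - K, C, log_a, log_b, eps),
                  entropic_plan(f, g, C, log_a, log_b, eps)))  # True
```

Because the weights $a$ and $b$ each sum to one, the dual value $\langle a, f\rangle + \langle b, g\rangle$ is also unchanged by the shift.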

Figures (11)

  • Figure 1: Trajectory comparison between Flow Matching and the NSGF++ model on the CIFAR-10 task. The NSGF++ model quickly recovers the target structure and progressively refines the details in subsequent steps.
  • Figure 2: NSGF++ framework
  • Figure 3: Visualization of generated 2D paths. Each method drives the particles from the prior distribution (black) to the target distribution (blue). The flow color indicates the step index, from blue (step $0$) to red (step $T$).
  • Figure 4: 2-Wasserstein distance along the generation process for neural ODE-based diffusion models and NSGF. The FM/SI methods reduce the noise roughly linearly, whereas NSGF quickly recovers the target structure and progressively refines the details in subsequent steps.
  • Figure 5: Inference results of the NSGF++ model. The first row shows the samples after 5 NSGF steps; the second row shows the final results.
  • ...and 6 more figures
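Figure 4 reports the 2-Wasserstein distance along the generation process, but the summary does not say how it was computed. One exact option for two sample sets of equal size, shown here as an assumption rather than the paper's evaluation code, uses the fact that an optimal coupling between two uniform empirical measures with the same number of points can be taken to be a permutation, so $W_2$ reduces to a linear assignment problem solved by SciPy's `linear_sum_assignment`. The function name `w2_empirical` is hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def w2_empirical(x: np.ndarray, y: np.ndarray) -> float:
    """Exact 2-Wasserstein distance between uniform empirical measures on x and y.

    Assumes x and y have the same shape (n, d); the optimal coupling is then a
    permutation, found by solving an assignment problem on squared distances."""
    if x.shape != y.shape:
        raise ValueError("x and y must contain the same number of points")
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))  # W2^2 = (1/n) * sum of matched costs


# Usage: distance between generated samples and held-out target samples.
rng = np.random.default_rng(0)
generated = rng.normal(size=(100, 2))
target = rng.normal(size=(100, 2)) + 2.0
print(w2_empirical(generated, target))
```

The assignment step costs $O(n^3)$ in the worst case, so for large sample sets an entropic (Sinkhorn) approximation is the usual cheaper alternative.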

Theorems & Definitions (16)

  • Definition 1
  • Lemma 1
  • Definition 2
  • Theorem 1
  • Proposition 1
  • Remark 1
  • Theorem 2
  • Proposition 2
  • Lemma 2
  • Definition 3
  • ...and 6 more