Wasserstein proximal operators describe score-based generative models and resolve memorization

Benjamin J. Zhang, Siting Liu, Wuchen Li, Markos A. Katsoulakis, Stanley J. Osher

TL;DR

The paper addresses memorization and interpretability challenges in score-based generative models (SGMs) by recasting SGMs as Wasserstein proximal operators (WPOs) of the cross-entropy within a mean-field game (MFG) framework. Through a Cole–Hopf transform, the backward HJB equation is connected to an uncontrolled forward diffusion. This motivates a kernel-based score model that encodes the manifold structure of the data and requires fewer training samples. The central contribution is the WPO-informed kernel model. Its local precision matrices around the kernel centers are learned by enforcing the MFG terminal condition via implicit score matching, which yields a smooth density that generalizes instead of memorizing the training data. The results provide a principled bridge between optimal transport, PDEs, and manifold learning. They offer faster training and explicit density estimation, and they suggest scalable neural architectures for high-dimensional SGMs.
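The following shows why a Cole–Hopf transform linearizes an HJB equation. It uses a simplified, hypothetical case with unit diffusion and no drift or potential term. The paper's HJB equation has additional terms and its own sign and time conventions. For a backward HJB equation in $U$, the substitution $\eta = e^{U}$ gives:

```latex
\begin{aligned}
&\partial_t U + \tfrac12 |\nabla U|^2 + \tfrac12 \Delta U = 0,
  \qquad \eta := e^{U}, \\
&\partial_t \eta = \eta\,\partial_t U, \qquad
 \nabla \eta = \eta\,\nabla U, \qquad
 \Delta \eta = \eta\left(\Delta U + |\nabla U|^2\right), \\
&\partial_t \eta + \tfrac12 \Delta \eta
 = \eta\left(\partial_t U + \tfrac12 \Delta U + \tfrac12 |\nabla U|^2\right) = 0 .
\end{aligned}
```

Setting $\tau = T - t$ turns this into the heat equation $\partial_\tau \eta = \tfrac12 \Delta \eta$. This is a drift-free (uncontrolled) Fokker–Planck equation run forward in $\tau$, with $\nabla \log \eta = \nabla U$. The nonlinear term $\tfrac12|\nabla U|^2$ is absorbed exactly by the substitution. Because the resulting equation is linear, $\eta$ at any time is the convolution of its initial data with the heat kernel, and this is where a kernel representation for $\nabla \log \eta$ comes from.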

Abstract

We focus on the fundamental mathematical structure of score-based generative models (SGMs). We first formulate SGMs in terms of the Wasserstein proximal operator (WPO) and demonstrate that, via mean-field games (MFGs), the WPO formulation reveals mathematical structure that describes the inductive bias of diffusion and score-based models. In particular, MFGs yield optimality conditions in the form of a pair of coupled partial differential equations: a forward controlled Fokker-Planck (FP) equation and a backward Hamilton-Jacobi-Bellman (HJB) equation. Via a Cole-Hopf transformation and taking advantage of the fact that the cross-entropy can be related to a linear functional of the density, we show that the HJB equation is an uncontrolled FP equation. Second, with the mathematical structure at hand, we present an interpretable kernel-based model for the score function which dramatically improves the performance of SGMs in terms of the number of training samples and the training time required. In addition, the WPO-informed kernel model is explicitly constructed to avoid the recently studied memorization effects of score-based generative models. The mathematical form of the new kernel-based models in combination with the use of the terminal condition of the MFG reveals new explanations for the manifold learning and generalization properties of SGMs, and provides a resolution to their memorization effects. Finally, our mathematically informed, interpretable kernel-based model suggests new scalable bespoke neural network architectures for high-dimensional applications.
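A Gaussian mixture whose kernels carry their own precision matrices gives a rough picture of a kernel-based score model. The sketch below is an illustration under stated assumptions, not the paper's implementation. It takes $p(x) = \frac1N \sum_i \mathcal{N}(x; z_i, P_i^{-1})$ with hypothetical centers $z_i$ and precisions $P_i$ and returns the score and its divergence in closed form. It then evaluates the implicit score matching objective $\mathbb{E}[\tfrac12 |s(x)|^2 + \nabla \cdot s(x)]$, which the TL;DR associates with enforcing the terminal condition. All function names are hypothetical.

```python
import numpy as np


def precisions_from_cholesky(gammas):
    """Hypothetical parameterization: P_i = Gamma_i Gamma_i^T, Gamma_i lower-triangular."""
    return gammas @ np.swapaxes(gammas, -1, -2)


def kernel_score_and_divergence(x, centers, precisions):
    """Score s(x) = grad log p(x) and div s(x) for the mixture
    p(x) = (1/N) sum_i N(x; z_i, P_i^{-1}) with local precision matrices P_i.

    x: (d,), centers: (N, d), precisions: (N, d, d) symmetric positive definite.
    """
    diffs = centers - x                                   # z_i - x, shape (N, d)
    a = np.einsum("nij,nj->ni", precisions, diffs)        # a_i = P_i (z_i - x)
    _, logdets = np.linalg.slogdet(precisions)            # log det P_i for each kernel
    # Log kernel values up to the shared constant -(d/2) log(2 pi).
    logits = -0.5 * np.sum(diffs * a, axis=1) + 0.5 * logdets
    w = np.exp(logits - logits.max())
    w /= w.sum()                                          # posterior weight of each kernel
    s = w @ a                                             # s = sum_i w_i a_i
    traces = np.trace(precisions, axis1=1, axis2=2)
    # Trace of the Jacobian: sum_i w_i |a_i|^2 - |s|^2 - sum_i w_i tr(P_i).
    div = w @ np.sum(a * a, axis=1) - s @ s - w @ traces
    return s, div


def implicit_score_matching_loss(samples, centers, precisions):
    """Empirical implicit score matching objective: mean of 0.5 |s|^2 + div s."""
    total = 0.0
    for x in samples:
        s, div = kernel_score_and_divergence(x, centers, precisions)
        total += 0.5 * s @ s + div
    return total / len(samples)
```

In the paper the precision matrices are learned (Figure 5 mentions a trained Cholesky factorization $\boldsymbol{\Gamma}_{\theta}$). Here they are fixed inputs. Take a single kernel at the origin with $P = \lambda I$ in $d$ dimensions. Against standard Gaussian data the objective is $\tfrac12 \lambda^2 d - d\lambda$, which is minimized at $\lambda = 1$, so the model recovers the data's own precision.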

Paper Structure

This paper contains 22 sections, 2 theorems, 40 equations, 6 figures, and 1 algorithm.

Key Result

Proposition 3.1

(Kernel representation formula for the score function). For initial condition $\eta(x,0) = \pi(x)$, the score function $\mathsf{s}(x,t) = \nabla \log \eta(x,t)$ in the denoising SDE has the kernel representation formula
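The sketch below illustrates the kind of kernel formula involved and the memorization effect described in Figure 1. It computes the exact score of a Gaussian-smoothed empirical distribution. It assumes the standard Ornstein–Uhlenbeck forward process $dX_t = -X_t\,dt + \sqrt{2}\,dW_t$, for which $X_t \mid X_0 = y \sim \mathcal{N}(e^{-t} y, (1 - e^{-2t}) I)$. This time convention and the function names are hypothetical and may differ from the paper's.

```python
import numpy as np


def empirical_kernel_score(x, data, t):
    """Score of the OU-smoothed empirical distribution at time t > 0.

    Assumes X_t | X_0 = y ~ N(exp(-t) y, (1 - exp(-2t)) I).
    x: (d,) query point; data: (N, d) training samples.
    """
    var = 1.0 - np.exp(-2.0 * t)
    centers = np.exp(-t) * data                        # shrunken kernel centers, (N, d)
    diffs = centers - x                                # vectors from x to each center
    logits = -np.sum(diffs**2, axis=1) / (2.0 * var)   # log Gaussian kernels, up to a constant
    weights = np.exp(logits - logits.max())            # numerically stable softmax
    weights /= weights.sum()
    return weights @ diffs / var                       # kernel-weighted average direction


def empirical_log_density(x, data, t):
    """Log of the same smoothed empirical density (used to check the score)."""
    var = 1.0 - np.exp(-2.0 * t)
    d = data.shape[1]
    sq = np.sum((np.exp(-t) * data - x) ** 2, axis=1)
    log_kernels = -sq / (2.0 * var) - 0.5 * d * np.log(2.0 * np.pi * var)
    m = log_kernels.max()
    return m + np.log(np.mean(np.exp(log_kernels - m)))
```

The score is a softmax-weighted average of the directions $(e^{-t} y_i - x)/(1 - e^{-2t})$ toward the shrunken training points. As $t \to 0$ the weights concentrate on the nearest training sample, so a sampler driven by this score collapses onto the training set. This is the memorization that the WPO-informed kernel model is designed to avoid.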

Figures (6)

  • Figure 1: The core idea of this paper is that since the score function has a kernel representation formula, approximations to the score should respect the structure of the formula. Using the empirical distribution to construct a kernel formula for the score function memorizes the training data. Our WPO-informed kernel model learns local precision matrices via the terminal condition of an HJB equation, which produces a kernel-based model that generalizes better and exhibits manifold learning.
  • Figure 2: 2500 samples generated by different models when the training dataset size is limited to $5 \times 10^4$.
  • Figure 3: Density plots constructed by evaluating the kernel density estimated with the WPO-informed kernel model, for the experiments shown in Figure 2. These plots are evaluated densities, not densities reconstructed from samples.
  • Figure 4: Six-dimensional example: a 3D Swiss roll noisily embedded in 6D space. This is a proof of concept that the WPO-informed kernel density model scales at least to moderate dimensions.
  • Figure 5: We use ellipses to represent the local covariance matrices obtained from the WPO-informed kernel model; they show that the trained model reveals the underlying data manifold. The model is trained on the two moons dataset, as in the experiment depicted in Figure 3. We draw $25$ samples $Z_i$, $i = 1, \dots, 25$, from the trained kernel centers and obtain the learned covariance matrices by evaluating the trained Cholesky factorization $\boldsymbol{\Gamma}_{\theta}(x)^{-1}$ at $x = Z_i$. Each ellipse is centered at $Z_i$, and its orientation and axes are determined by the eigenvectors and eigenvalues of the corresponding covariance matrix. The right panel is a zoomed-in view of the left panel.
  • ...and 1 more figure
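Figure 5 draws each learned local covariance as an ellipse obtained from its eigendecomposition. The sketch below shows that step. It assumes, hypothetically, that the trained factor at a center is a lower-triangular matrix $\Gamma$ with local precision $P = \Gamma\Gamma^\top$, so the covariance is $P^{-1}$. The caption's exact convention for $\boldsymbol{\Gamma}_{\theta}(x)^{-1}$ may differ, and the function name is hypothetical.

```python
import numpy as np


def covariance_ellipse(center, gamma, n_std=1.0):
    """Ellipse parameters for one learned local covariance.

    Assumes gamma is a lower-triangular factor of the local precision
    matrix, P = gamma @ gamma.T, so the covariance is inv(P).
    Returns (center, width, height, angle_deg) with full axis lengths.
    """
    precision = gamma @ gamma.T
    cov = np.linalg.inv(precision)
    eigvals, eigvecs = np.linalg.eigh(cov)       # ascending eigenvalues, orthonormal columns
    major = eigvecs[:, -1]                       # direction of largest variance
    angle_deg = np.degrees(np.arctan2(major[1], major[0]))
    width = 2.0 * n_std * np.sqrt(eigvals[-1])   # full length along the major axis
    height = 2.0 * n_std * np.sqrt(eigvals[0])   # full length along the minor axis
    return center, width, height, angle_deg
```

The returned values correspond to the arguments of matplotlib's `matplotlib.patches.Ellipse(xy, width, height, angle=...)`, which expects full axis lengths and an angle in degrees. `n_std` selects the Gaussian contour at that many standard deviations along each principal axis.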

Theorems & Definitions (11)

  • Proposition 3.1
  • Proof of Proposition 3.1
  • Remark 3.1
  • Remark 3.2
  • Remark 3.3
  • Proposition 4.1
  • Proof of Proposition 4.1
  • Remark 4.1
  • Remark 4.2
  • Remark 4.3
  • ...and 1 more