
Score-Based Generative Models Detect Manifolds

Jakiw Pidstrigach

TL;DR

The paper develops a theoretical framework for score-based generative models (SGMs) by analyzing the forward and reverse SDEs and the effects of approximating both the initial distribution of the reverse SDE ($p_T$) and the score drift. It derives conditions under which the generated samples have the same support as the data manifold, which separates memorization from genuine generalization. A key insight is that the drift approximation error, measured against the score of the empirical data distribution, must be unbounded for the SGM to generalize, and that this explosion of the drift near the terminal time is tied to the manifold structure of the data. The work also gives guidance on choosing the prior that approximates $p_T$ and discusses broader implications for the theoretical understanding of SGMs.

Abstract

Score-based generative models (SGMs) need to approximate the scores $\nabla \log p_t$ of the intermediate distributions as well as the final distribution $p_T$ of the forward process. The theoretical underpinnings of the effects of these approximations are still lacking. We find precise conditions under which SGMs are able to produce samples from an underlying (low-dimensional) data manifold $\mathcal{M}$. This assures us that SGMs are able to generate the "right kind of samples". For example, taking $\mathcal{M}$ to be the subset of images of faces, we find conditions under which the SGM robustly produces an image of a face, even though the relative frequencies of these images might not accurately represent the true data generating distribution. Moreover, this analysis is a first step towards understanding the generalization properties of SGMs: Taking $\mathcal{M}$ to be the set of all training samples, our results provide a precise description of when the SGM memorizes its training data.
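
Both approximations named in the abstract can be written out explicitly. The display below is a sketch that assumes an Ornstein-Uhlenbeck forward SDE on $[0, T]$, one common choice in the SGM literature; the paper's own forward process, scaling and time convention may differ. Under this assumption the reverse SDE is the standard time reversal of the forward one, and for an empirical data distribution over training points $x_1, \ldots, x_N$ the density $\hat{p}_t$ is an explicit Gaussian mixture.

```latex
% Assumed forward (noising) SDE: Ornstein-Uhlenbeck on [0, T]
\mathrm{d}X_t = -X_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t, \qquad X_0 \sim \mu_\text{data}, \quad X_t \sim p_t .

% Exact time reversal: Y_t \sim p_{T-t}, hence Y_T \sim \mu_\text{data}
\mathrm{d}Y_t = \bigl(Y_t + 2\,\nabla \log p_{T-t}(Y_t)\bigr)\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t, \qquad Y_0 \sim p_T .

% What an SGM simulates instead: the two approximations
Y_0 \sim \mu_\text{prior} \approx p_T, \qquad \nabla \log p_{T-t}(Y_t) \;\rightsquigarrow\; s_\theta(Y_t, T-t), \qquad \mu_\text{sample} = \operatorname{Law}(Y_T).

% Empirical measure of N training points: p_t becomes a Gaussian mixture
\hat{p}_t(x) = \frac{1}{N} \sum_{i=1}^{N} \mathcal{N}\!\bigl(x;\, e^{-t} x_i,\, (1 - e^{-2t})\, I\bigr).
```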

Paper Structure

This paper contains 30 sections, 10 theorems, 72 equations, 4 figures, and 1 table.

Key Result

Corollary 1

Denote by $\hat{X}_t$ the forward SDE when started in the empirical measure $\pi_0 = \hat{\mu}_\text{data}$. Let $\int_0^T \| s_\theta(\hat{X}_t, t) - \nabla \log \hat{p}_t(\hat{X}_t) \| \, \mathrm{d}t$ be the drift approximation error along a path of the forward SDE. For a given weighting function $w(t)$, th… (see Section sec:score_approximation). Simultaneously, if the exponential integral of the drift appro…
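
To make the path-wise drift error concrete, the sketch below estimates it along a single simulated forward path in a 1-D toy setting. It assumes an Ornstein-Uhlenbeck forward SDE $\mathrm{d}X_t = -X_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}W_t$ (so $\hat{p}_t$ is an explicit Gaussian mixture), starts the integral at a small cutoff `t_min` because the empirical score is undefined at $t = 0$, and uses the plain (unsquared) absolute error. The training points and the helper names `empirical_score`, `path_drift_error` and `shifted` are hypothetical, chosen only for illustration.

```python
import math
import random


def empirical_score(x, t, data):
    """Score of p_hat_t for a 1-D empirical measure, ASSUMING the OU forward SDE
    dX = -X dt + sqrt(2) dW, so p_hat_t is a mixture of N(e^{-t} x_i, 1 - e^{-2t})."""
    mean_scale = math.exp(-t)
    var = 1.0 - math.exp(-2.0 * t)
    # Log-weights of the equally weighted components, shifted by their max for stability.
    logw = [-(x - mean_scale * xi) ** 2 / (2.0 * var) for xi in data]
    top = max(logw)
    w = [math.exp(lw - top) for lw in logw]
    posterior_mean = sum(wi * mean_scale * xi for wi, xi in zip(w, data)) / sum(w)
    return (posterior_mean - x) / var


def path_drift_error(approx_score, data, T=1.0, t_min=1e-3, n_steps=1000, seed=0):
    """Left Riemann sum of |approx_score - empirical_score| over [t_min, T] along ONE
    forward path started at a random training point. t_min > 0 because the
    empirical score is undefined at t = 0."""
    rng = random.Random(seed)
    h = (T - t_min) / n_steps
    # Exact OU transition from t = 0 to t = t_min.
    x = rng.choice(data)
    x = math.exp(-t_min) * x + math.sqrt(1.0 - math.exp(-2.0 * t_min)) * rng.gauss(0.0, 1.0)
    t, total = t_min, 0.0
    for _ in range(n_steps):
        total += abs(approx_score(x, t) - empirical_score(x, t, data)) * h
        # Exact OU transition over one step of length h (no discretization error).
        x = math.exp(-h) * x + math.sqrt(1.0 - math.exp(-2.0 * h)) * rng.gauss(0.0, 1.0)
        t += h
    return total


if __name__ == "__main__":
    data = [-1.0, 0.5, 2.0]  # hypothetical training points

    def shifted(x, t):
        # Hypothetical approximate drift: exact empirical score plus a constant 0.5.
        return empirical_score(x, t, data) + 0.5

    print(path_drift_error(shifted, data))  # approximately 0.5 * (1 - 1e-3) = 0.4995
```

The constant-offset case is only a sanity check: the estimate equals the offset times the length of $[t_{\min}, T]$. Evaluating a trained model in place of `shifted` along the same path would give the quantity the corollary is about.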

Figures (4)

  • Figure 1: Top left: The leftmost plot shows the true data distribution $\mu_\text{data}$, which is a Gaussian mixture. The heat maps show the intermediate densities $p_t$ of $X_t$, followed by a line plot of $p_t$ at $t=1$. Bottom left: The rightmost plot shows $\mu_\text{prior}$, which is a standard Gaussian and differs from $p_1$. We start the reverse SDE (reversesde) in $\mu_\text{prior}$, but instead of the true score we use an approximation with a constant error of $3$, $s(x,t) = \nabla \log p_{1-t}(x) + 3$. Again, the heat maps show how the densities $q_t$ of $Y_t$ evolve, now backwards in time. The leftmost plot shows the resulting distribution $q_1$, which is used as the sample distribution, $\mu_\text{sample} = q_1$. Right: The densities $\mu_\text{data}$ and $\mu_\text{sample}$ are shown for direct comparison. The approximation errors in $\mu_\text{prior}$ and in the drift lead to an incorrect sample distribution, $\mu_\text{sample} \neq \mu_\text{data}$. Nevertheless, $\mu_\text{sample}$ is supported in the same region as $\mu_\text{data}$. For details on the numerical implementation see Appendix sec:numerics.
  • Figure 2: Perturbing $\nabla \log \hat{p}_t$.
  • Figure 3: (a): Both lines correspond to the same experiment with different drifts in the reverse SDE. For both lines we started $N=1000$ paths at the zero vector in $\mathbb{R}^{32\times32\times3}$. For the blue line we used the pretrained CIFAR-10 DDPM++ model from DBLP:conf/iclr/0011SKKEP21, whereas for the orange line we used the exact empirical drift $\nabla \log \hat{p}_t$, where $\hat{p}_t$ is a mixture of $50\,000$ Gaussians, one for each training example in CIFAR-10. Along each path we recorded the distance from $Y_t$ to the closest CIFAR-10 training example; the plot shows the average of this distance. While the reverse SDE driven by $\nabla \log \hat{p}_t$ ends at distance $0$ from the training examples, the SDE with the pretrained drift stays away from them and therefore produces novel images. (b): We evaluate $\nabla \log \hat{p}_t$ as in (a). We repeat the experiment of Figure fig:sphere_samples_both_disturbed on CIFAR-10 and perturb the empirical drift $\nabla \log \hat{p}_t$ with a constant error vector. The first row shows the samples generated by adding the constant error vector $e(x, t) = (1, 1, \ldots, 1)\in \mathbb{R}^{32 \times 32 \times 3}$ to $\nabla \log \hat{p}_t$. The second row shows, for each sample, the closest image in the CIFAR-10 dataset with respect to the Euclidean distance on $\mathbb{R}^{32 \times 32 \times 3}$. All sampled images are nearly identical to a corresponding CIFAR-10 image: the distance to the closest image is around $0.07$ for every plotted image. As in Figure fig:sphere_example, we can observe the effect of adding the one-vector: the sample distribution $\mu_\text{sample}$ is skewed towards images with high pixel values, which look mostly white to the human eye.
In the third and fourth rows we repeat the experiment of the first and second rows, but add the negative one-vector $e(x,t) = (-1, -1, \ldots, -1)$ and obtain black images.
  • Figure 4: Left: We simulated the reverse SDE on CIFAR-10, once with the pretrained CIFAR-10 DDPM++ model $s_\theta$ from DBLP:conf/iclr/0011SKKEP21 and once with a perturbed drift $s(x,t) = \nabla \log \hat{p}_t(x) + \frac{1}{2}(1,1, \ldots, 1)$. We then evaluated the integral (equ:novikov) numerically for varying upper limits $t=T$. For the perturbed drift, the integral does not appear to explode as $t\to 1$, which suggests that $Z_t$ is a martingale. For the DDPM++ drift, the integral explodes, so we cannot infer that $Z_t$ is a martingale. We used $N = 12\,000$ simulations of each SDE to generate this plot. Right: We again ran the two SDEs with the drifts from the left panel. This time we measured the average distance to the empirical drift, $\|s(\hat{Y}_t, t) - \nabla \log \hat{p}_t(\hat{Y}_t)\|$, along a path of the reverse SDE. We repeated the experiment $N = 2560$ times and plotted the mean distance. For the constant perturbation the distance is, as expected, constant. The distance of the DDPM++ drift to $\nabla \log \hat{p}_t$ is initially very small but explodes as $t \to 1$. By our results, this explosion is necessary for the SGM to generalize.
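
The 1-D experiment of Figure 1 can be imitated in a few lines. This is not the authors' implementation; it is a sketch that assumes an Ornstein-Uhlenbeck forward SDE on $[0, 1]$, a hypothetical symmetric two-component mixture for $\mu_\text{data}$, a standard Gaussian $\mu_\text{prior}$, and a plain Euler-Maruyama discretization of the reverse SDE. As in Figure 1, the score is perturbed by a constant error of $3$; the names `mixture_score`, `sample_reverse_sde` and `fraction_near_modes` are hypothetical.

```python
import math
import random

# Hypothetical stand-in for mu_data: a symmetric two-component 1-D Gaussian mixture.
WEIGHTS = [0.5, 0.5]
MEANS = [-2.0, 2.0]
STDS = [0.1, 0.1]


def mixture_score(x, t):
    """Exact score of p_t, ASSUMING the OU forward SDE dX = -X dt + sqrt(2) dW:
    each component N(m, s^2) evolves into N(e^{-t} m, e^{-2t} s^2 + 1 - e^{-2t})."""
    a = math.exp(-t)
    log_terms, comp_scores = [], []
    for w, m, s in zip(WEIGHTS, MEANS, STDS):
        var = a * a * s * s + 1.0 - a * a
        log_terms.append(math.log(w) - 0.5 * math.log(var) - (x - a * m) ** 2 / (2.0 * var))
        comp_scores.append(-(x - a * m) / var)
    top = max(log_terms)
    resp = [math.exp(lt - top) for lt in log_terms]  # unnormalized responsibilities
    return sum(r * cs for r, cs in zip(resp, comp_scores)) / sum(resp)


def sample_reverse_sde(score_error=0.0, n_paths=400, n_steps=400, T=1.0, seed=0):
    """Euler-Maruyama for dY = (Y + 2 s(Y, T - tau)) dtau + sqrt(2) dB on [0, T],
    started in mu_prior = N(0, 1), with s = exact score + constant score_error."""
    rng = random.Random(seed)
    h = T / n_steps
    samples = []
    for _ in range(n_paths):
        y = rng.gauss(0.0, 1.0)  # mu_prior: standard Gaussian, close to but not equal to p_T
        for k in range(n_steps):
            t = T - k * h  # forward time; never reaches 0
            s = mixture_score(y, t) + score_error
            y += (y + 2.0 * s) * h + math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)
        samples.append(y)
    return samples


def fraction_near_modes(samples, radius=0.5):
    """Share of samples within `radius` of some mixture mean (a proxy for 'same support')."""
    return sum(min(abs(y - m) for m in MEANS) < radius for y in samples) / len(samples)


if __name__ == "__main__":
    for err in (0.0, 3.0):
        ys = sample_reverse_sde(score_error=err)
        print(f"error={err}: mean={sum(ys) / len(ys):+.2f}, near modes={fraction_near_modes(ys):.2f}")
```

With `score_error=3.0` the samples should still sit almost entirely near the two mixture means, while the mass shifts towards the right-hand component, mirroring the caption's observation that $\mu_\text{sample} \neq \mu_\text{data}$ even though the support is preserved. Here the score is indexed by forward time $T - \tau$, whereas Figure 1 writes $s(x,t) = \nabla \log p_{1-t}(x) + 3$ in reverse time; the two conventions describe the same drift.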

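A toy 1-D analogue of Figure 3 (a), Figure 3 (b) and the right panel of Figure 4 is sketched below. It assumes the same Ornstein-Uhlenbeck forward SDE, under which $\nabla \log \hat{p}_t$ is the score of a Gaussian mixture centred at $e^{-t} x_i$ with variance $1 - e^{-2t}$. Four hypothetical training points replace CIFAR-10, and a "smoothed" drift (the empirical score computed with a hypothetical extra variance `extra_var`) stands in for a learned network such as DDPM++; because it stays bounded as $t \to 0$, its distance to $\nabla \log \hat{p}_t$ must blow up there. Three drifts are compared: the exact empirical score, the empirical score plus a constant $1$, and the smoothed drift.

```python
import math
import random

TRAIN = [-3.0, -1.0, 1.0, 3.0]  # hypothetical 1-D "training set"


def empirical_score(y, t, extra_var=0.0):
    """Score of the OU-noised empirical measure (ASSUMED forward SDE dX = -X dt + sqrt(2) dW).
    extra_var = 0: exact score of p_hat_t, a mixture of N(e^{-t} x_i, 1 - e^{-2t}).
    extra_var > 0: HYPOTHETICAL smoothed drift that stays bounded as t -> 0."""
    a = math.exp(-t)
    var = 1.0 - a * a + extra_var
    logw = [-(y - a * x) ** 2 / (2.0 * var) for x in TRAIN]
    top = max(logw)
    w = [math.exp(lw - top) for lw in logw]
    post_mean = sum(wi * a * x for wi, x in zip(w, TRAIN)) / sum(w)
    return (post_mean - y) / var


def run(drift, n_paths=150, n_steps=400, T=1.0, seed=0):
    """Euler-Maruyama for dY = (Y + 2 drift(Y, t)) dtau + sqrt(2) dB started in N(0, 1),
    where t = T - tau is forward time. Returns the final samples and the mean
    |drift - exact empirical score| at the first and at the last step.
    No noise is added on the last step, so the output is the denoised endpoint."""
    rng = random.Random(seed)
    h = T / n_steps
    ys, err_first, err_last = [], 0.0, 0.0
    for _ in range(n_paths):
        y = rng.gauss(0.0, 1.0)
        for k in range(n_steps):
            t = T - k * h  # forward time, from T down to h
            s = drift(y, t)
            err = abs(s - empirical_score(y, t))
            if k == 0:
                err_first += err / n_paths
            if k == n_steps - 1:
                err_last += err / n_paths
            y += (y + 2.0 * s) * h
            if k < n_steps - 1:
                y += math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)
        ys.append(y)
    return ys, err_first, err_last


def mean_nearest_distance(ys):
    """Average distance from each sample to its closest training point."""
    return sum(min(abs(y - x) for x in TRAIN) for y in ys) / len(ys)


DRIFTS = {
    "exact": lambda y, t: empirical_score(y, t),
    "exact + 1": lambda y, t: empirical_score(y, t) + 1.0,
    "smoothed": lambda y, t: empirical_score(y, t, extra_var=0.1),
}

if __name__ == "__main__":
    for name, drift in DRIFTS.items():
        ys, e_first, e_last = run(drift)
        print(f"{name:>9}: distance={mean_nearest_distance(ys):.3f} "
              f"mean={sum(ys) / len(ys):+.2f} drift error {e_first:.2f} -> {e_last:.2f}")
```

In this toy setting one would expect the first two drifts to end essentially on training points (memorization, with the constant shift only skewing which points are chosen), while the smoothed drift keeps a positive distance and its error against $\nabla \log \hat{p}_t$ grows sharply at the final step. The code runs forward time from $T$ down to one step size, so its "$t \to 0$" corresponds to "$t \to 1$" in the reverse-time convention of Figure 4.
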
Theorems & Definitions (21)

  • Corollary 1
  • Lemma 1
  • Theorem 1
  • Theorem 2
  • Lemma 2
  • Lemma 3
  • Definition 1
  • Definition 2
  • Lemma 4
  • proof
  • ...and 11 more