CountsDiff: A Diffusion Model on the Natural Numbers for Generation and Imputation of Count-Based Data

Renzo G. Soatto, Anders Hoel, Greycen Ren, Shorna Alam, Stephen Bates, Nikolaos P. Daskalakis, Caroline Uhler, Maria Skoularidou

Abstract

Diffusion models have excelled at generative tasks for both continuous and token-based domains, but their application to discrete ordinal data remains underdeveloped. We present CountsDiff, a diffusion framework designed to natively model distributions on the natural numbers. CountsDiff extends the Blackout diffusion framework by simplifying its formulation through a direct parameterization in terms of a survival probability schedule and an explicit loss weighting. This introduces flexibility through design parameters with direct analogues in existing diffusion modeling frameworks. Beyond this reparameterization, CountsDiff introduces features from modern diffusion models, previously absent in counts-based domains, including continuous-time training, classifier-free guidance, and churn/remasking reverse dynamics that allow non-monotone reverse trajectories. We propose an initial instantiation of CountsDiff and validate it on natural image datasets (CIFAR-10, CelebA), exploring the effects of varying the introduced design parameters in a complex, well-studied, and interpretable data domain. We then highlight biological count assays as a natural use case, evaluating CountsDiff on single-cell RNA-seq imputation in a fetal cell and heart cell atlas. Remarkably, we find that even this simple instantiation matches or surpasses the performance of a state-of-the-art discrete generative model and leading RNA-seq imputation methods, while leaving substantial headroom for further gains through optimized design choices in future work.

Paper Structure

This paper contains 68 sections, 3 theorems, 62 equations, 11 figures, 15 tables, 2 algorithms.

Key Result

Proposition 3.1

Given $p:[0,1] \rightarrow [0,1]$ differentiable, monotonically decreasing, and with endpoints $p(0) = 1$, $p(1) = 0$, there exists a CountsDiff forward process with $p$-schedule $p(t)$.
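
To make the role of the $p$-schedule concrete, the following is a minimal sketch rather than the authors' implementation. It assumes that the forward marginal of a pure death process is binomial thinning, i.e. each unit of an initial count $x_0$ survives to time $t$ independently with probability $p(t)$, and it assumes one plausible cosine form, $p(t) = \cos^2(\pi t / 2)$, which satisfies the conditions of the proposition. The function names `cosine_p_schedule` and `forward_thin` are hypothetical.

```python
import numpy as np


def cosine_p_schedule(t):
    """Assumed cosine survival schedule: p(0) = 1, p(1) = 0, decreasing on [0, 1]."""
    t = np.asarray(t, dtype=float)
    return np.cos(0.5 * np.pi * t) ** 2


def forward_thin(x0, t, rng):
    """Draw x_t given x_0 by binomial thinning (assumed forward marginal).

    Each of the x0 units survives independently with probability p(t),
    so x_t ~ Binomial(x0, p(t)) and 0 <= x_t <= x0.
    """
    p = cosine_p_schedule(t)
    return rng.binomial(x0, p)


rng = np.random.default_rng(0)
x0 = np.array([0, 3, 10, 250])      # toy count vector
ts = np.linspace(0.0, 1.0, 5)       # t = 0, 0.25, 0.5, 0.75, 1

# Independent draws from each marginal; these are not a coupled trajectory.
samples = np.stack([forward_thin(x0, t, rng) for t in ts])

for t, row in zip(ts, samples):
    print(f"t={t:.2f}  p(t)={cosine_p_schedule(t):.3f}  x_t={row}")
```

At $t = 0$ every count is preserved and at $t = 1$ every count has decayed to zero, matching the endpoint conditions $p(0) = 1$ and $p(1) = 0$. Any other differentiable, monotonically decreasing schedule with the same endpoints could be substituted for `cosine_p_schedule` without changing the sampler.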

Figures (11)

  • Figure 1: Visualization of CountsDiff's forward corruption process (top) and reverse sampling process (bottom). The top diagram depicts the progression of a $p$-schedule, a pure death process. The bottom shows a single step of the generalized, birth-death sampling process.
  • Figure 2: Histogram of model-generated samples versus ground truth and distributional distance metrics (top); variance statistics (bottom) for a subset of dimensions. Existing diffusion models exhibit failure cases even in a simple toy dataset: Gaussian diffusion suffers from mode collapse, while masked diffusion overfits outliers (inflated variance). Full results for all ten dimensions can be found in Appendix \ref{appendix:simulated_experiments}.
  • Figure 3: Five images guided by each CIFAR-10 class, sampled from CountsDiff with guidance scale $2.0$ and $\eta_{\text{rescale}} = 0.005$.
  • Figure 4: Nine images drawn from CountsDiff trained on CIFAR-10 with the $\eta_\text{rescale}$ attrition schedule, at varying levels of $\eta_\text{rescale}$.
  • Figure 5: Converted $p$-schedule from Blackout Diffusion (see \ref{appendix: converting blackout noise schedule}) versus cosine $p$-schedule.
  • ...and 6 more figures

Theorems & Definitions (6)

  • Proposition 3.1
  • Proposition 3.2: Reverse step with attrition
  • proof
  • Proposition 2.1
  • proof
  • proof