DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents

Yilun Xu, Gabriele Corso, Tommi Jaakkola, Arash Vahdat, Karsten Kreis

TL;DR

The paper addresses the challenge of learning multimodal data distributions with diffusion models by introducing Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff), which couple a small set of learnable discrete latents with the standard continuous Gaussian prior of a diffusion model. Training proceeds in two stages: the denoiser and the encoder that infers the discrete latents are trained end-to-end, aided by Gumbel-Softmax relaxation and classifier-free guidance, and an autoregressive prior over the discrete latents is then fit separately. Empirically, DisCo-Diff reduces the curvature of the generative ODE, makes the noise-to-data mapping easier to learn, and achieves state-of-the-art FID scores on class-conditioned ImageNet-64/128 with an ODE sampler, while also improving performance on molecular docking. Because the approach does not rely on pre-trained networks and needs only a few discrete variables with small codebooks, it improves fidelity at modest computational overhead and should apply to other data modalities and diffusion-based generative models.

Abstract

Diffusion models (DMs) have revolutionized generative learning. They utilize a diffusion process to encode data into a simple Gaussian distribution. However, encoding a complex, potentially multimodal data distribution into a single continuous Gaussian distribution arguably represents an unnecessarily challenging learning problem. We propose Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff) to simplify this task by introducing complementary discrete latent variables. We augment DMs with learnable discrete latents, inferred with an encoder, and train DM and encoder end-to-end. DisCo-Diff does not rely on pre-trained networks, making the framework universally applicable. The discrete latents significantly simplify learning the DM's complex noise-to-data mapping by reducing the curvature of the DM's generative ODE. An additional autoregressive transformer models the distribution of the discrete latents, a simple step because DisCo-Diff requires only a few discrete variables with small codebooks. We validate DisCo-Diff on toy data, several image synthesis tasks, and molecular docking, and find that introducing discrete latents consistently improves model performance. For example, DisCo-Diff achieves state-of-the-art FID scores on class-conditioned ImageNet-64/128 datasets with an ODE sampler.

Paper Structure

This paper contains 37 sections, 6 equations, 17 figures, 5 tables, 3 algorithms.

Figures (17)

  • Figure 1: Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff) augment DMs with additional discrete latent variables that capture global appearance patterns, here shown for images of huskies. (a) During training, discrete latents are inferred through an encoder, which for images is a vision transformer (Dosovitskiy et al., 2021), and fed to the DM via cross-attention. Backpropagation is facilitated by continuous relaxation with a Gumbel-Softmax distribution. To sample novel images, an additional autoregressive model is learned over the distribution of discrete latents. (b) Schematic visualization of generative denoising diffusion trajectories. Different colors indicate different discrete latent variables, pushing the trajectories toward different modes.
  • Figure 2: Samples generated from DisCo-Diff trained on the ImageNet dataset: (a) randomly sampled discrete latents and class labels; (b) samples in each grid sharing the same discrete latent. The class label for the top/bottom row is fixed to coffeepot/malamute.
  • Figure 3: Modeling 2D mixture of Gaussians. Left: Data distribution. Middle: Data generated by a regular DM. Right: Data generated by DisCo-Diff. We use different colors to distinguish data generated by different discrete latents. We further provide zoom-ins and visualize some ODE trajectories by dotted lines.
  • Figure 4: Modeling 2D mixture of Gaussians: analysis. The mean curvature (left) and norm of the neural networks' Jacobians (right) along the reverse-time ODE trajectories as a function of $t$.
  • Figure 5: Group hierarchical DisCo-Diff. Different discrete latents are fed to the denoiser U-Net at different feature resolutions.
  • ...and 12 more figures
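
The Figure 1 caption notes that backpropagation through the discrete latents relies on a continuous relaxation with a Gumbel-Softmax distribution. The following is a minimal, framework-free sketch of that relaxation for a single discrete latent; it is not the authors' implementation. The function name `gumbel_softmax`, the temperature `tau`, and the example logits are assumptions introduced here for illustration only.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Return a relaxed one-hot sample (a probability vector) for one latent.

    logits: unnormalized scores over the codebook entries (hypothetical
            encoder outputs in this sketch).
    tau:    temperature; smaller values push the sample closer to one-hot.
    rng:    any object with a .random() method returning a float in [0, 1).
    """
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise via inverse transform: g = -log(-log(u)).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / tau)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


# Usage with made-up logits for a codebook of size 4.
rng = random.Random(0)
codebook_logits = [2.0, 0.5, -1.0, 0.0]
soft = gumbel_softmax(codebook_logits, tau=0.5, rng=rng)

# A hard (discrete) code can be read off as the arg-max of the soft sample.
hard_index = max(range(len(soft)), key=soft.__getitem__)
```

In an automatic-differentiation framework, the soft sample stays differentiable with respect to the logits, which is what lets gradients from the denoiser reach the encoder. How the temperature is chosen or scheduled is not stated in this section, so `tau=0.5` above is arbitrary.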