
Matryoshka Diffusion Models

Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly

TL;DR

MDM tackles the challenge of high-resolution diffusion by modeling a single diffusion process over an extended space that jointly denoises inputs at multiple resolutions. Central to the approach is the NestedUNet, which nests the feature streams of lower resolutions inside those of higher resolutions so that computation is shared. It is coupled with a progressive training schedule that adds higher resolutions gradually. Trained on CC12M and WebVid-10M, respectively, the method achieves high-quality 1024x1024 text-to-image and 16x256x256 text-to-video results with strong zero-shot generalization, and it converges faster than cascaded or latent diffusion baselines. This end-to-end, multi-resolution framework offers a scalable, practical path to high-resolution generative modeling without resorting to multi-stage pipelines.
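The "extended space" can be made concrete: one clean image is downsampled to each resolution, and every copy is noised at the same timestep so the network sees all of them at once. The following is a minimal NumPy sketch, assuming average-pooling downsampling and a single (alpha_t, sigma_t) pair shared across resolutions (the paper may use resolution-specific noise schedules). `avg_pool` and `nested_noisy_inputs` are hypothetical helper names, not part of the released code.

```python
from typing import Dict, Sequence, Tuple

import numpy as np


def avg_pool(x: np.ndarray, factor: int) -> np.ndarray:
    """Downsample an (H, W, C) array by averaging non-overlapping factor x factor blocks."""
    h, w, c = x.shape
    assert h % factor == 0 and w % factor == 0, "size must be divisible by factor"
    return x.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))


def nested_noisy_inputs(
    x: np.ndarray,
    resolutions: Sequence[int],
    alpha_t: float,
    sigma_t: float,
    rng: np.random.Generator,
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
    """Noise every downsampled copy of a square image x at one shared timestep.

    Returns (noisy inputs z_t^r, clean targets x^r), keyed by resolution r.
    Assumption: the same (alpha_t, sigma_t) is used for every resolution.
    """
    full = x.shape[0]  # assumes a square H == W image
    noisy, targets = {}, {}
    for r in resolutions:
        assert full % r == 0, "each resolution must evenly divide the full size"
        x_r = avg_pool(x, full // r)            # clean image at resolution r
        eps = rng.standard_normal(x_r.shape)    # independent Gaussian noise per resolution
        noisy[r] = alpha_t * x_r + sigma_t * eps
        targets[r] = x_r
    return noisy, targets


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.uniform(-1.0, 1.0, size=(256, 256, 3))
    noisy, _ = nested_noisy_inputs(x, [64, 128, 256], alpha_t=0.8, sigma_t=0.6, rng=rng)
    print({r: z.shape for r, z in noisy.items()})
    # {64: (64, 64, 3), 128: (128, 128, 3), 256: (256, 256, 3)}
```

Sharing one timestep across all copies is what makes this a single diffusion process over the joint space rather than several independent ones.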

Abstract

Diffusion models are the de facto approach for generating high-quality images and videos, but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion Models (MDM), an end-to-end framework for high-resolution image and video synthesis. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small-scale inputs are nested within those of large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions, which leads to significant improvements in optimization for high-resolution generation. We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a single pixel-space model at resolutions of up to 1024x1024 pixels, demonstrating strong zero-shot generalization using the CC12M dataset, which contains only 12 million images. Our code is released at https://github.com/apple/ml-mdm.
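The progressive schedule can be pictured as a growing set of active resolutions. Training begins with the lowest resolution alone, and higher resolutions (together with the outer NestedUNet blocks that process them) join later. In this sketch, lower resolutions stay in the loss after higher ones are added. `SCHEDULE`, `active_resolutions`, and the step thresholds are hypothetical illustrations rather than the paper's configuration.

```python
from typing import List, Sequence, Tuple

# Hypothetical schedule: (training step at which a resolution joins, resolution).
# Resolutions and thresholds are illustrative, not the paper's exact settings.
SCHEDULE: Tuple[Tuple[int, int], ...] = ((0, 64), (200_000, 256), (400_000, 1024))


def active_resolutions(
    step: int, schedule: Sequence[Tuple[int, int]] = SCHEDULE
) -> List[int]:
    """Resolutions whose denoising losses are trained at `step`, ordered low to high."""
    return sorted(res for start, res in schedule if step >= start)


if __name__ == "__main__":
    for step in (0, 199_999, 200_000, 500_000):
        print(step, active_resolutions(step))
```

Under this sketch, each training step would compute the denoising loss only over `active_resolutions(step)`, so the early steps only pay for the low-resolution part of the network.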

Paper Structure

This paper contains 39 sections, 4 equations, 20 figures, and 4 tables.

Figures (20)

  • Figure 1: ($\leftarrow\uparrow$) Images generated by MDM at $64^2$, $128^2$, $256^2$, $512^2$, and $1024^2$ resolutions using the prompt "a deer Matryoshka doll in Japanese kimono, super details, extreme realistic, 8k"; ($\leftarrow\downarrow$) $1$ and $16$ frames of $64^2$ video generated by our method using the prompt "pouring milk into black coffee"; all other samples are at $1024^2$ given various prompts. Images were resized for ease of visualization.
  • Figure 2: An illustration of Matryoshka Diffusion. $z^L_t$, $z^M_t$ and $z^H_t$ are noisy images at three different resolutions, which are fed into the denoising network together, and predict targets independently.
  • Figure 3: An illustration of the NestedUNet architecture used in Matryoshka Diffusion. Following the design of SDXL (Podell et al., 2023), we allocate more computation to the low-resolution feature maps (for example, by using more attention layers); in the figure, the width of a block denotes its parameter count. Black arrows indicate connections inherited from the UNet, and red arrows indicate additional connections introduced by the NestedUNet.
  • Figure 4: Comparison against baselines during training. FID ($\downarrow$; a, b) and CLIP ($\uparrow$; c) scores of samples generated without CFG during training of different class-conditional models on ImageNet $256\times 256$ (a) and CC12M $256 \times 256$ (b, c). MDM models that were first trained at a lower resolution (for 200K steps on ImageNet and 390K steps on CC12M) converge much faster.
  • Figure 5: Random samples from our class-conditional MDM trained on ImageNet $256\times 256$.
  • ...and 15 more figures
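Figure 2 describes the noisy images at all resolutions being fed to the denoising network together, while each resolution predicts its own target independently. A training loss consistent with that description is a weighted sum of per-resolution regression errors. The sketch below assumes a mean-squared error, user-supplied predictions standing in for NestedUNet outputs, and optional per-resolution weights. `multires_loss` and its parameters are hypothetical names, and the choice of prediction target (clean image, noise, or another parameterization) is left to the caller.

```python
from typing import Iterable, Mapping, Optional

import numpy as np


def multires_loss(
    preds: Mapping[int, np.ndarray],
    targets: Mapping[int, np.ndarray],
    active: Iterable[int],
    weights: Optional[Mapping[int, float]] = None,
) -> float:
    """Weighted sum of per-resolution mean-squared errors.

    Each resolution r is scored against its own target independently;
    resolutions outside `active` are ignored. Missing weights default to 1.0.
    """
    total = 0.0
    for r in active:
        w = 1.0 if weights is None else weights.get(r, 1.0)
        total += w * float(np.mean((preds[r] - targets[r]) ** 2))
    return total


if __name__ == "__main__":
    targets = {64: np.zeros((64, 64, 3)), 256: np.zeros((256, 256, 3))}
    preds = {64: np.full((64, 64, 3), 0.5), 256: np.full((256, 256, 3), 1.0)}
    print(multires_loss(preds, targets, active=[64]))        # 0.25
    print(multires_loss(preds, targets, active=[64, 256]))   # 1.25
```

Passing the active set explicitly lets the same loss serve every stage of a progressive schedule, since resolutions that have not yet been added simply contribute nothing.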