Matryoshka Diffusion Models
Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly
TL;DR
Matryoshka Diffusion Models (MDM) address the difficulty of training high-resolution diffusion models by defining a single diffusion process over an extended space that jointly denoises multiple resolutions. Central to the approach is the NestedUNet, which nests the features and parameters for low-resolution inputs within those for high-resolution inputs so that computation is shared across scales, coupled with a progressive training schedule that adds higher resolutions gradually. The method achieves high-quality 1024x1024 text-to-image and 16x256x256 text-to-video results, trained on CC12M and WebVid-10M with strong zero-shot generalization, and converges faster than cascaded or latent diffusion baselines. This end-to-end, multi-resolution framework offers a scalable, practical path to high-resolution generative modeling without multi-stage pipelines.
Abstract
Diffusion models are the de facto approach for generating high-quality images and videos, but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion Models (MDM), an end-to-end framework for high-resolution image and video synthesis. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small-scale inputs are nested within those of large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions, which leads to significant improvements in optimization for high-resolution generation. We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a single pixel-space model at resolutions of up to 1024x1024 pixels, demonstrating strong zero-shot generalization using the CC12M dataset, which contains only 12 million images. Our code is released at https://github.com/apple/ml-mdm
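The two ideas in the abstract, noising several resolutions of one image together and unlocking higher resolutions over the course of training, can be illustrated with a minimal pure-Python sketch. Everything here is an assumption made for illustration: the function names (downsample, build_pyramid, noise_jointly, active_resolutions), the use of average pooling, the single shared signal level alpha, and the unlock steps are hypothetical and are not taken from the paper or from the released ml-mdm code.

```python
import math
import random


def downsample(image, factor):
    """Average-pool a square 2D list of floats by an integer factor."""
    size = len(image) // factor
    return [
        [
            sum(
                image[r * factor + dr][c * factor + dc]
                for dr in range(factor)
                for dc in range(factor)
            ) / factor ** 2
            for c in range(size)
        ]
        for r in range(size)
    ]


def build_pyramid(image, num_levels):
    """Return [full-res, half-res, quarter-res, ...] views of one image."""
    return [downsample(image, 2 ** level) for level in range(num_levels)]


def noise_jointly(pyramid, alpha, rng):
    """Noise every resolution at the same signal level alpha (assumed),
    drawing independent Gaussian noise per pixel:
    x_t = alpha * x + sqrt(1 - alpha^2) * eps."""
    sigma = math.sqrt(1.0 - alpha ** 2)
    return [
        [[alpha * px + sigma * rng.gauss(0.0, 1.0) for px in row] for row in level]
        for level in pyramid
    ]


def active_resolutions(step, unlock_steps):
    """Progressive schedule: resolutions whose (hypothetical) unlock step
    has been reached, from lowest to highest."""
    return sorted(res for res, start in unlock_steps.items() if step >= start)


# Usage: an 8x8 toy "image" viewed at three nested resolutions.
toy_image = [[float(r * 8 + c) for c in range(8)] for r in range(8)]
toy_pyramid = build_pyramid(toy_image, 3)
toy_noisy = noise_jointly(toy_pyramid, alpha=0.8, rng=random.Random(0))
print([len(level) for level in toy_noisy])  # sizes: [8, 4, 2]

# Hypothetical schedule: start at 64, add 256 and then 1024 later.
toy_schedule = {64: 0, 256: 1000, 1024: 5000}
print(active_resolutions(1500, toy_schedule))  # [64, 256]
```

In this sketch a denoiser would receive all noised levels at once and predict a clean target for each, while the schedule decides which levels contribute to the loss at a given training step. How the actual model shares features across levels is described by the NestedUNet and is not reproduced here.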
