How Diffusion Models Learn to Factorize and Compose

Qiyao Liang, Ziming Liu, Mitchell Ostrow, Ila Fiete

TL;DR

This work considers a highly reduced setting to examine whether and when diffusion models learn semantically meaningful and factorized representations of composable features, and connects manifold formation in diffusion models to percolation theory in physics, offering insight into the sudden onset of factorized representation learning.

Abstract

Diffusion models are capable of generating photo-realistic images that combine elements which likely do not appear together in the training set, demonstrating the ability to compositionally generalize. Nonetheless, the precise mechanism of compositionality and how it is acquired through training remain elusive. Inspired by cognitive neuroscientific approaches, we consider a highly reduced setting to examine whether and when diffusion models learn semantically meaningful and factorized representations of composable features. We performed extensive controlled experiments on conditional Denoising Diffusion Probabilistic Models (DDPMs) trained to generate various forms of 2D Gaussian bump images. We found that the models learn factorized but not fully continuous manifold representations for encoding continuous features of variation underlying the data. With such representations, models demonstrate superior feature compositionality but limited ability to interpolate over unseen values of a given feature. Our experimental results further demonstrate that diffusion models can attain compositionality with few compositional examples, suggesting a more efficient way to train DDPMs. Finally, we connect manifold formation in diffusion models to percolation theory in physics, offering insight into the sudden onset of factorized representation learning. Our thorough toy experiments thus contribute to a deeper understanding of how diffusion models capture compositional structure in data.
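As a rough illustration of the training data described in the abstract and in Figures 1 and 4, the sketch below builds $32\times 32$ images of a 2D Gaussian bump and of a 2D Gaussian SOS, plus a cross-shaped held-out split of the conditions. It is a hypothetical reconstruction, not the authors' code: the bump width `sigma`, the integer grid of centers, the reading of "SOS" as the sum of a vertical and a horizontal 1D Gaussian stripe, and the inclusive held-out bounds 13 to 19 are all assumptions, and every function name is illustrative. The $(x, y)$ center is the quantity a conditional DDPM would receive as its conditioning input.

```python
import numpy as np


def ring_offset(coords, center, n):
    """Signed offset from `center`, wrapped onto a ring of length n (periodic boundary)."""
    return (coords - center + n / 2) % n - n / 2


def profile(center, n=32, sigma=1.0, periodic=True):
    """1D Gaussian profile along one image axis (sigma is an assumed width)."""
    coords = np.arange(n, dtype=float)
    d = ring_offset(coords, center, n) if periodic else coords - center
    return np.exp(-d**2 / (2 * sigma**2))


def gaussian_bump(x0, y0, n=32, sigma=1.0, periodic=True):
    """2D Gaussian bump: outer product of a y-profile (rows) and an x-profile (columns)."""
    return np.outer(profile(y0, n, sigma, periodic), profile(x0, n, sigma, periodic))


def vertical_stripe(x0, n=32, sigma=1.0, periodic=True):
    """1D Gaussian stripe that varies only along x (every row is identical)."""
    return np.tile(profile(x0, n, sigma, periodic), (n, 1))


def horizontal_stripe(y0, n=32, sigma=1.0, periodic=True):
    """1D Gaussian stripe that varies only along y (every column is identical)."""
    return np.tile(profile(y0, n, sigma, periodic)[:, None], (1, n))


def gaussian_sos(x0, y0, n=32, sigma=1.0, periodic=True):
    """Assumed 'sum of stripes': vertical stripe at x0 plus horizontal stripe at y0."""
    return vertical_stripe(x0, n, sigma, periodic) + horizontal_stripe(y0, n, sigma, periodic)


def in_test_region(x0, y0, lo=13, hi=19):
    """Cross-shaped held-out region: x0 or y0 lies in [lo, hi] (inclusive bounds assumed)."""
    return lo <= x0 <= hi or lo <= y0 <= hi


# Hypothetical training conditions: integer centers 0..31 that lie outside the cross.
train_conditions = [(x, y) for x in range(32) for y in range(32) if not in_test_region(x, y)]

if __name__ == "__main__":
    img = gaussian_bump(5, 9)
    print(img.shape, np.unravel_index(np.argmax(img), img.shape))
    print(len(train_conditions), "training conditions")
```

Because the bump is the product of independent $x$ and $y$ profiles, the image varies along two separable factors, which is the kind of structure the experiments probe for factorized representations. Setting `periodic=True` wraps the offsets so that opposite image edges meet, matching the periodic-boundary setting of Figure 2.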

Paper Structure

This paper contains 26 sections, 3 equations, and 15 figures.

Figures (15)

  • Figure 1: Example $32\times 32$ image data of a 2D Gaussian bump (left) and a 2D Gaussian SOS (right).
  • Figure 2: Metrics of a model trained using 2D Gaussian bump datasets with periodic boundaries. (a) 2D projections of a standard 3D torus (left) and a 4D Clifford torus (right). The 3D torus is an example of a coupled representation that can be learned by the model, while the 4D torus is a factorized one. (b) Persistence diagrams of a standard torus (left; the diagram looks the same for Clifford tori) and of the learned representation of the model at the terminal epoch (right). There are two overlapping orange points for $H_1$ in both diagrams. (c) Model accuracy (top) and effective dimension (bottom) of the representation learned by the model as a function of training epochs. (d) PCA eigenspectrum (first 15 dimensions) of the model's learned representations, with the corresponding sample accuracy (%) and explained variance ratio of the top 4 PCs (labeled at the top right of each panel), at various checkpoints during training. (e)-(g) PCA visualizations of the learned representations at epoch 0, epoch 150, and the terminal epoch, respectively.
  • Figure 3: Comparison of orthogonality and parallelism test statistics between 3D torus, model's learned representation, and Clifford torus. $x$-on-$y$ (top row), $x$-on-$x$ (middle row), $y$-on-$y$ (bottom row) orthogonality (left column) and parallelism (right column) test statistics are compared between (a) an ideal 3D torus (blue), (b) the learned representation by the model (green), and (c) an ideal Clifford torus (orange).
  • Figure 4: Compositional generalization of models trained on Gaussian SOS datasets to held-out test regions. We train three models on different Gaussian SOS datasets to test their ability to compositionally generalize to the red-shaded, held-out test regions shown in the sample image (f). (a) The 2D Gaussian SOS dataset contains all combinations of 2D Gaussian SOSs for $x$ and $y$ between 0 and 32, except for the held-out range between 13 and 19. (b) The 1D Gaussian stripe dataset contains horizontal and vertical 1D Gaussian stripes spanning the full range of $x$ and $y$ values between 0 and 32. The accuracy of the three models in generating the correct $x$ and $y$ locations of the Gaussian SOS is shown for different sections of the test regions: (c) the vertical section excluding the intersection, (d) the horizontal section excluding the intersection, and (e) the intersection. (f) Sample image of a 2D Gaussian SOS with the different test regions labeled. (g) Accuracy of models trained with various subsampling rates of the 2D Gaussian bump + 1D Gaussian stripe dataset.
  • Figure 5: Sample efficiency gains from training the model on independent factors of variation. (a)-(b): Results on $N=32$ 2D Gaussian bump generation. (a) Model accuracy in generating 2D Gaussian bumps after training on 2D Gaussian bumps, shown as a function of the subsampling percentage. (b) Model accuracy in generating 2D Gaussian bumps after training on mixed 2D Gaussian bumps + 1D Gaussian stripes. Red dashed lines in (a) and (b) mark threshold accuracies of 0%, 60%, and 100%. (c) Log-log plot of the dataset size needed to reach the 60% threshold accuracy as a function of image size $N$, for 2D Gaussian bump training data (blue) versus mixed 2D Gaussian bump + 1D Gaussian stripe training data (orange). The distinct data-efficiency scalings are visualized by dashed gray and black lines, which provide linear and quadratic references, respectively.
  • ...and 10 more figures
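Figure 2 contrasts a coupled 3D torus with a factorized 4D Clifford torus and tracks an effective dimension derived from the PCA eigenspectrum. The following sketch samples both tori on the same $32\times 32$ grid of angles $(\theta, \phi)$ and compares their spectra. It assumes the participation ratio $(\sum_i \lambda_i)^2 / \sum_i \lambda_i^2$ of the covariance eigenvalues $\lambda_i$ as the effective-dimension measure (the paper's exact definition is not stated here) and uses hypothetical radii $R=2$, $r=1$. In the 3D torus, $\phi$ enters the first two coordinates through the factor $R + r\cos\phi$, which couples the two angles; in the Clifford torus, each angle occupies its own orthogonal 2D plane.

```python
import numpy as np


def torus_3d(theta, phi, R=2.0, r=1.0):
    """Standard torus in R^3: theta and phi are coupled through (R + r cos phi)."""
    ring = R + r * np.cos(phi)
    return np.stack([ring * np.cos(theta), ring * np.sin(theta), r * np.sin(phi)], axis=-1)


def clifford_torus(theta, phi):
    """Clifford torus in R^4: theta and phi live in orthogonal 2D planes."""
    return np.stack([np.cos(theta), np.sin(theta), np.cos(phi), np.sin(phi)], axis=-1)


def pca_eigenvalues(points):
    """Eigenvalues of the sample covariance, largest first (the PCA eigenspectrum)."""
    centered = points - points.mean(axis=0)
    cov = centered.T @ centered / len(points)
    return np.sort(np.linalg.eigvalsh(cov))[::-1]


def participation_ratio(eigvals):
    """Assumed effective-dimension measure: (sum of eigenvalues)^2 / sum of squares."""
    return eigvals.sum() ** 2 / (eigvals**2).sum()


# Evenly spaced angles over a full period, combined into a 32 x 32 grid of (theta, phi).
n = 32
grid = 2 * np.pi * np.arange(n) / n
theta, phi = np.meshgrid(grid, grid, indexing="ij")
theta, phi = theta.ravel(), phi.ravel()

if __name__ == "__main__":
    for name, pts in [("3D torus", torus_3d(theta, phi)),
                      ("Clifford torus", clifford_torus(theta, phi))]:
        ev = pca_eigenvalues(pts)
        print(name, np.round(ev, 3), round(float(participation_ratio(ev)), 2))
```

On this grid the 3D torus yields eigenvalues 2.25, 2.25, and 0.5 (participation ratio about 2.41), while the Clifford torus yields four equal eigenvalues of 0.5 (participation ratio 4). Under this measure, the factorized torus spreads its variance evenly over four dimensions, whereas the coupled one concentrates it in fewer than three.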