Disentangled representations via score-based variational autoencoders

Benjamin S. H. Lyo, Eero P. Simoncelli, Cristina Savin

TL;DR

SAMI learns disentangled representations without supervision by integrating diffusion-based generative modeling with variational autoencoders. The key idea is to use a conditional diffusion model as the VAE's generator and to derive an exact ELBO that permits reusing the inference network to guide diffusion; this yields latent spaces with semantically meaningful axes and straighter latent trajectories for video. Empirically, SAMI recovers factorized latent factors from synthetic disks and CelebA, achieves competitive sample quality, and can obtain semantic axes from pretrained diffusion models. The work provides theoretical insight into how diffusion priors induce structure in latent representations and suggests new ways to discover semantic axes without supervision.

Abstract

We present the Score-based Autoencoder for Multiscale Inference (SAMI), a method for unsupervised representation learning that combines the theoretical frameworks of diffusion models and VAEs. By unifying their respective evidence lower bounds, SAMI formulates a principled objective that learns representations through score-based guidance of the underlying diffusion process. The resulting representations automatically capture meaningful structure in the data: SAMI recovers ground truth generative factors in our synthetic dataset, learns factorized, semantic latent dimensions from complex natural images, and encodes video sequences into latent trajectories that are straighter than those of alternative encoders, despite training exclusively on static images. Furthermore, SAMI can extract useful representations from pre-trained diffusion models with minimal additional training. Finally, the explicitly probabilistic formulation provides new ways to identify semantically meaningful axes in the absence of supervised labels, and its mathematical exactness allows us to make formal statements about the nature of the learned representation. Overall, these results indicate that implicit structural information in diffusion models can be made explicit and interpretable through synergistic combination with a variational autoencoder.
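
The abstract describes learning representations through score-based guidance of the diffusion process. One standard way to read such guidance (an assumption here, not a derivation taken from the paper) is the Bayes-rule identity $\nabla_x \log p(x \mid z) = \nabla_x \log p(x) + \nabla_x \log q(z \mid x)$, in which the encoder's log-density supplies the correction to the unconditional score. The sketch below checks this identity on a 1-D linear-Gaussian toy model where every term has a closed form; the function names and parameter values are hypothetical and chosen only for illustration.

```python
import numpy as np

# Toy setting (illustrative assumption, not the SAMI architecture):
#   noisy image  x ~ N(0, v)
#   encoder      q(z | x) = N(w * x, s2)

def unconditional_score(x, v):
    """grad_x log p(x) for p(x) = N(0, v)."""
    return -x / v

def encoder_guidance(x, z, w, s2):
    """grad_x log q(z | x) for a linear-Gaussian encoder."""
    return w * (z - w * x) / s2

def guided_score(x, z, v, w, s2):
    """Conditional score assembled from the unconditional score plus guidance."""
    return unconditional_score(x, v) + encoder_guidance(x, z, w, s2)

def exact_conditional_score(x, z, v, w, s2):
    """grad_x log p(x | z), computed from the closed-form Gaussian posterior."""
    post_var = 1.0 / (1.0 / v + w**2 / s2)
    post_mean = post_var * w * z / s2
    return -(x - post_mean) / post_var

x = np.linspace(-3.0, 3.0, 7)
params = dict(z=0.8, v=2.0, w=1.5, s2=0.25)

# The two routes to the conditional score agree up to floating-point error.
assert np.allclose(guided_score(x, **params), exact_conditional_score(x, **params))
```

In a real model the unconditional score would come from a trained denoiser and the guidance term from differentiating the inference network, so neither would be available in closed form.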

Paper Structure

This paper contains 42 sections, 103 equations, 7 figures, 3 tables, 2 algorithms.

Figures (7)

  • Figure 1: A) Graphical model of SAMI, contrasted with those of standard VAEs and unconditioned diffusion models. B) Schematic of conditional sampling procedure. C) Schematic illustration of how movement in the latent space results in guidance of the denoiser score towards regions of the clean image manifold that are semantically similar to the original image. For visualization, both latent and image spaces are depicted as two dimensional.
  • Figure 2: Disks dataset. A) Graphical model for the disks dataset, with coordinates of disk center ($c_x$, $c_y$) and background intensity $I_{b}$ as latents. B) Random draws from the ground truth generative process. C) Samples drawn from the model, conditioned on leftmost image. D) Posterior means for a grid of test $c_x$ and $c_y$ ground truth positions, fixed $I_{b}$. E) Same as D, but $c_y$ fixed during interpolation of the other two factors. F) Same as D, for all three factors.
  • Figure 3: CelebA. A) Conditional image generation; conditioning image -- red, samples -- blue. B) Linear traversal between latents of leftmost (yellow) and rightmost (purple) images. C) The spectrum of the posterior covariance for one example image from the test set. D) The posterior variance as a function of noise level; color intensity marks the ordering of the axes at the largest noise level. E) Moving along the linear combination of two latent axes with identified semantics. F) Geometric characterization of latent representation. Left: across-test data variability in posterior mean (blue) and posterior variance of a noisy image (orange) for each latent dimension. Middle: norm of the sensitivity of posterior variance to input changes. Right: Global coherence of each axis. The x axes of all three plots are sorted by global coherence. G) Example transformation of two different images along a latent axis identified as global. H) Weight sparsity of binary classifiers compares how closely the semantics of latent axes align to supervised labels for SAMI vs. DiffAE.
  • Figure 4: Unsupervised extraction of latent representation from a pre-trained diffusion model. A) Top: Samples from the trained model, conditioned on leftmost image (green). Bottom: samples drawn from the unconditional diffusion model. B) Straightness of latent trajectories over the course of one naturalistic video from CelebV. C) Average straightness of latent encoding (over 50 naturalistic videos) for SAMI, compared to trajectories in pixel space and latent trajectories of DiffAE.
  • Figure 5: Reduction in variance from conditioning on the posterior on CelebA.
  • ...and 2 more figures
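
The Figure 4 caption refers to the straightness of latent trajectories without defining the metric on this page. A common choice, assumed here rather than taken from the paper, is the mean angle between consecutive displacement vectors of the trajectory, where a smaller angle means a straighter path. The function name `mean_curvature_deg` and the example trajectories are hypothetical.

```python
import numpy as np

def mean_curvature_deg(traj):
    """Mean angle (degrees) between consecutive steps of a trajectory.

    traj: array of shape (T, D), one latent (or pixel) vector per frame.
    Assumes T >= 3 and that no two consecutive frames are identical,
    since a zero-length step has no direction.
    """
    traj = np.asarray(traj, dtype=float)
    steps = np.diff(traj, axis=0)                       # (T-1, D) displacements
    units = steps / np.linalg.norm(steps, axis=1, keepdims=True)
    cosines = np.sum(units[:-1] * units[1:], axis=1)    # (T-2,) cosines of turn angles
    angles = np.degrees(np.arccos(np.clip(cosines, -1.0, 1.0)))
    return float(angles.mean())

# A perfectly straight path has (numerically) zero curvature.
line = np.outer(np.arange(5), [1.0, 2.0])
assert np.isclose(mean_curvature_deg(line), 0.0, atol=1e-4)

# A single right-angle turn gives 90 degrees.
corner = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
assert np.isclose(mean_curvature_deg(corner), 90.0)
```

Under this definition, comparing encoders amounts to encoding each video frame, stacking the latents into a `(T, D)` array, and comparing the resulting mean angles against the same quantity computed on raw pixels.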