
Sculpting Latent Spaces With MMD: Disentanglement With Programmable Priors

Quentin Fruytier, Akshay Malhotra, Shahab Hamidi-Rad, Aditya Sant, Aryan Mokhtari, Sujay Sanghavi

TL;DR

The paper addresses the challenge of learning disentangled representations by critiquing KL-based regularization in VAEs and introducing a nonparametric, Maximum Mean Discrepancy (MMD)-based Programmable Priors framework. It defines the Latent Predictability Score (LPS) as an unsupervised measure of latent independence and demonstrates that the MMD approach can sculpt the aggregate posterior $q_{\theta_1}(z)$ to match a programmable prior $p(z)$, such as a Gaussian, uniform, or Gaussian-mixture distribution. Empirically, the method achieves state-of-the-art mutual independence on CIFAR-10 and Tiny ImageNet without sacrificing reconstruction quality, and it narrows the gap between independence and semantic alignment by engineering priors that match interpretable features. The work highlights programmable priors as a powerful tool for representation engineering, with implications for model identifiability and causal representation learning.
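The TL;DR describes sculpting the aggregate posterior $q_{\theta_1}(z)$ toward a chosen prior $p(z)$ with MMD. The sketch below illustrates that idea in NumPy under stated assumptions: an RBF kernel with a fixed bandwidth, the standard unbiased MMD² estimator, hypothetical parameterizations for the Gaussian, uniform, and Gaussian-mixture priors, and an "empirical" option that resamples stored latents in the spirit of the copying experiment in Figure 3. It is not the authors' implementation; `sample_prior`, `mmd2_unbiased`, and the bandwidth are illustrative choices.

```python
import numpy as np


def sample_prior(kind, n, d, rng, bank=None):
    """Draw n samples in R^d from a hypothetical 'programmable' prior."""
    if kind == "gaussian":
        return rng.standard_normal((n, d))
    if kind == "uniform":
        # Uniform on [-sqrt(3), sqrt(3)]: zero mean, unit variance per coordinate.
        return rng.uniform(-np.sqrt(3), np.sqrt(3), size=(n, d))
    if kind == "gmm":
        # Factorized mixture (assumed): each coordinate independently picks a mode at -2 or +2.
        modes = rng.choice([-2.0, 2.0], size=(n, d))
        return modes + 0.5 * rng.standard_normal((n, d))
    if kind == "empirical":
        # Resample rows of a stored latent bank (d is taken from the bank itself).
        return bank[rng.integers(0, len(bank), size=n)]
    raise ValueError(f"unknown prior: {kind}")


def rbf_kernel(a, b, bandwidth):
    """Gaussian (RBF) kernel matrix between the rows of a and b."""
    sq = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2 * a @ b.T
    return np.exp(-sq / (2 * bandwidth**2))


def mmd2_unbiased(z, p, bandwidth=1.0):
    """Unbiased MMD^2 estimate between a latent batch z and prior samples p."""
    n, m = len(z), len(p)
    kzz = rbf_kernel(z, z, bandwidth)
    kpp = rbf_kernel(p, p, bandwidth)
    kzp = rbf_kernel(z, p, bandwidth)
    # Exclude self-similarity (diagonal) terms from the within-set averages.
    term_zz = (kzz.sum() - np.trace(kzz)) / (n * (n - 1))
    term_pp = (kpp.sum() - np.trace(kpp)) / (m * (m - 1))
    return term_zz + term_pp - 2 * kzp.mean()


rng = np.random.default_rng(0)
n, d = 400, 4
z = sample_prior("uniform", n, d, rng)  # stand-in for a batch of encoder outputs
scores = {k: mmd2_unbiased(z, sample_prior(k, n, d, rng)) for k in ("gaussian", "uniform", "gmm")}
print(scores)
```

In a training loop, `z` would be a minibatch of encoder outputs in an autodiff framework, and a term such as `lam * mmd2_unbiased(z, prior_batch)` (with a hypothetical weight `lam`) would be added to the reconstruction loss. Because the estimator only needs samples of $p(z)$, changing the prior amounts to changing the sampler.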

Abstract

Learning disentangled representations, where distinct factors of variation are captured by independent latent variables, is a central goal in machine learning. The dominant approach has been the Variational Autoencoder (VAE) framework, which uses a Kullback-Leibler (KL) divergence penalty to encourage the latent space to match a factorized Gaussian prior. In this work, however, we provide direct evidence that this KL-based regularizer is an unreliable mechanism, consistently failing to enforce the target distribution on the aggregate posterior. We validate this and quantify the resulting entanglement using our novel, unsupervised Latent Predictability Score (LPS). To address this failure, we introduce the Programmable Prior Framework, a method built on the Maximum Mean Discrepancy (MMD). Our framework allows practitioners to explicitly sculpt the latent space, achieving state-of-the-art mutual independence on complex datasets like CIFAR-10 and Tiny ImageNet without the common reconstruction trade-off. Furthermore, we demonstrate how this programmability can be used to engineer sophisticated priors that improve alignment with semantically meaningful features. Ultimately, our work provides a foundational tool for representation engineering, opening new avenues for model identifiability and causal reasoning.
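The abstract introduces the LPS without spelling out how it is computed. As a hypothetical stand-in consistent with its description as an unsupervised measure of latent independence, the sketch below predicts each latent coordinate from the remaining ones with k-nearest-neighbour regression and averages the held-out $R^2$, clipped at zero. A nonlinear predictor is used deliberately: Figure 1 reports a model with diagonal covariance that is still co-dependent, which a purely linear check would miss. The function names, `k`, and the train/test split are assumptions, not the paper's definition.

```python
import numpy as np


def knn_predict(x_train, y_train, x_test, k):
    """Average the targets of the k nearest training points for each test point."""
    d2 = ((x_test[:, None, :] - x_train[None, :, :]) ** 2).sum(axis=-1)
    idx = np.argpartition(d2, k, axis=1)[:, :k]  # indices of the k smallest distances
    return y_train[idx].mean(axis=1)


def latent_predictability(z, k=10, test_frac=0.3, seed=0):
    """Hypothetical LPS-style score: mean held-out R^2 of predicting each
    latent coordinate from all the other coordinates."""
    rng = np.random.default_rng(seed)
    n, d = z.shape
    perm = rng.permutation(n)
    n_test = int(test_frac * n)
    test, train = perm[:n_test], perm[n_test:]
    r2 = []
    for j in range(d):
        others = [i for i in range(d) if i != j]
        pred = knn_predict(z[train][:, others], z[train, j], z[test][:, others], k)
        resid = np.mean((z[test, j] - pred) ** 2)
        r2.append(max(0.0, 1.0 - resid / np.var(z[test, j])))  # worse than the mean counts as 0
    return float(np.mean(r2))


rng = np.random.default_rng(1)
indep = rng.standard_normal((1000, 2))
a = rng.standard_normal(1000)
# Uncorrelated but dependent: cov(a, a**2 - 1) = E[a**3] = 0.
dep = np.stack([a, a**2 - 1], axis=1)
score_indep, score_dep = latent_predictability(indep), latent_predictability(dep)
print(score_indep, score_dep)
```

Here `dep` has a near-diagonal covariance yet scores well above `indep`, because its second coordinate is a function of the first. The magnitudes are not comparable to the LPS values reported in the paper.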

Paper Structure

This paper contains 33 sections, 5 equations, 13 figures, 15 tables.

Figures (13)

  • Figure 1: Visual comparison of learned latent distributions on MNIST for key baselines. Each column corresponds to a different model and shows the covariance matrix of the latent features over the dataset (top row) and histograms of each latent feature's marginal distribution over the whole dataset (bottom row). The VAE-based models (a, b) and the standard AE (e) fail to enforce either independence or the target Gaussian geometry. While a batch-wise KLD regularizer (d) achieves a diagonal covariance, its parametric nature prevents it from correctly shaping the marginal distributions, leading to co-dependence (LPS of 0.97). Only our MMD-regularized model (c) enforces both properties, producing a representation that is independent and matches the target prior.
  • Figure 2: Visual Representation of the Programmable Prior Framework.
  • Figure 3: Visualizing the Latent Space Copying experiment on MNIST. (b) First, a standard Autoencoder was trained, and its complex, entangled latent distribution was saved to serve as an empirical prior. (a) Then, a new "student" model was trained using our MMD regularizer, tasked with replicating this arbitrary prior. The figure displays the covariance matrix between the latent space features over the whole dataset and the histogram plots of the marginal distribution of the latent space features over the whole dataset. The near-identical covariance matrices and marginal distributions visually confirm our method's ability to enforce a complex, co-dependent latent geometry that would be impossible to specify analytically with KLD.
  • Figure 4: Example images (represented in an $8\times8$ grid) from the synthetic datasets used in our experiments. From left to right: (a) dSprites, with factors like shape, scale, and orientation; (b-d) The XY family of datasets, with factors for position (XY), color (C), and shape (S).
  • Figure 5: CIFAR-10: Reconstruction and LPS versus latent space size. Line plots of (left) the reconstruction NMSE in decibels and (right) the mutual independence of the latent features, estimated with our Latent Predictability Score (LPS), for all baselines as a function of latent space dimension.
  • ...and 8 more figures
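Figure 1(d) attributes the failure of the batch-wise KLD regularizer to its parametric nature. A minimal sketch, assuming the regularizer fits a Gaussian to the batch mean and covariance and compares it to $\mathcal{N}(0, I)$ in closed form (the exact form used in the paper is not given here): such a penalty sees only first and second moments, so a batch whose second coordinate is a deterministic function of the first can still score near zero. The helper `nmse_db` is a hypothetical implementation of the dB-scaled reconstruction error plotted in Figure 5, using one common definition of NMSE.

```python
import numpy as np


def batch_gaussian_kld(z):
    """KL(N(mu, Sigma) || N(0, I)), with mu and Sigma the moments of the batch z."""
    d = z.shape[1]
    mu = z.mean(axis=0)
    sigma = np.cov(z, rowvar=False)
    _, logdet = np.linalg.slogdet(sigma)
    return 0.5 * (np.trace(sigma) + mu @ mu - d - logdet)


def nmse_db(x, x_hat):
    """Normalized MSE in decibels: 10 * log10(||x - x_hat||^2 / ||x||^2)."""
    return 10 * np.log10(np.sum((x - x_hat) ** 2) / np.sum(x**2))


rng = np.random.default_rng(0)
n = 5000
a = rng.standard_normal(n)
b = rng.standard_normal(n)
# Uncorrelated, unit-variance, but fully dependent: coordinate 2 is a function of coordinate 1.
dependent = np.stack([a, (a**2 - 1) / np.sqrt(2)], axis=1)
# Gaussian pair with unit variances and correlation 0.9.
correlated = np.stack([a, 0.9 * a + np.sqrt(1 - 0.81) * b], axis=1)

kld_dependent = batch_gaussian_kld(dependent)
kld_correlated = batch_gaussian_kld(correlated)
print(kld_dependent, kld_correlated)
```

The correlated pair is penalized (about $-\tfrac{1}{2}\log(1 - 0.9^2) \approx 0.83$ nats), while the dependent pair passes almost unpenalized, mirroring the diagonal-covariance-yet-co-dependent behaviour described for model (d).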