Sculpting Latent Spaces With MMD: Disentanglement With Programmable Priors
Quentin Fruytier, Akshay Malhotra, Shahab Hamidi-Rad, Aditya Sant, Aryan Mokhtari, Sujay Sanghavi
TL;DR
The paper tackles disentangled representation learning by critiquing KL-based regularization in VAEs and introducing Programmable Priors, a nonparametric framework based on Maximum Mean Discrepancy (MMD). It defines the Latent Predictability Score (LPS), an unsupervised measure of latent independence, and shows that the MMD objective can sculpt the aggregate posterior $q_{\theta_1}(z)$ to match a programmable prior $p(z)$, such as a Gaussian, uniform, or Gaussian mixture distribution. Empirically, the method achieves state-of-the-art mutual independence on CIFAR-10 and Tiny ImageNet without sacrificing reconstruction quality, and it narrows the gap between statistical independence and semantic alignment by engineering priors that match interpretable features. The work positions programmable priors as a tool for representation engineering, with implications for model identifiability and causal representation learning.
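To make the MMD idea concrete, here is a minimal NumPy sketch of the squared-MMD penalty between a batch of latent codes and samples drawn from a prior. It assumes a Gaussian (RBF) kernel with a fixed bandwidth and the standard unbiased estimator; the kernel, the bandwidth, the uniform range, and the per-dimension mixture are illustrative choices, and the helper names `rbf_kernel`, `mmd2_unbiased`, and `sample_prior` are hypothetical, not the authors' implementation. Because MMD needs only samples from $p(z)$, changing the prior amounts to changing the sampler.

```python
import numpy as np


def rbf_kernel(a: np.ndarray, b: np.ndarray, bandwidth: float) -> np.ndarray:
    """Gaussian kernel matrix k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * bandwidth^2))."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negative values caused by round-off
    return np.exp(-sq_dists / (2.0 * bandwidth**2))


def mmd2_unbiased(z: np.ndarray, z_prior: np.ndarray, bandwidth: float = 1.0) -> float:
    """Unbiased estimate of MMD^2 between codes z ~ q(z) and samples z_prior ~ p(z)."""
    n, m = z.shape[0], z_prior.shape[0]
    k_zz = rbf_kernel(z, z, bandwidth)
    k_pp = rbf_kernel(z_prior, z_prior, bandwidth)
    k_zp = rbf_kernel(z, z_prior, bandwidth)
    # Drop the i == j terms so the within-sample averages are unbiased.
    term_zz = (k_zz.sum() - np.trace(k_zz)) / (n * (n - 1))
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (m * (m - 1))
    term_zp = k_zp.mean()
    return float(term_zz + term_pp - 2.0 * term_zp)


def sample_prior(kind: str, n: int, dim: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical samplers for a few factorized priors (illustrative parameters)."""
    if kind == "gaussian":
        return rng.standard_normal((n, dim))
    if kind == "uniform":
        # Assumed support [-1, 1] in every dimension.
        return rng.uniform(-1.0, 1.0, size=(n, dim))
    if kind == "gmm":
        # Each coordinate drawn independently from 0.5 * N(-2, 1) + 0.5 * N(2, 1).
        centers = rng.choice([-2.0, 2.0], size=(n, dim))
        return centers + rng.standard_normal((n, dim))
    raise ValueError(f"unknown prior kind: {kind}")


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    prior = sample_prior("gaussian", 500, 4, rng)
    matched = rng.standard_normal((500, 4))        # codes already distributed like the prior
    shifted = rng.standard_normal((500, 4)) + 1.5  # codes with a mean offset
    print(mmd2_unbiased(matched, prior), mmd2_unbiased(shifted, prior))
```

The first printed value should be close to zero (it may be slightly negative, since the estimator is unbiased), while the second is clearly positive. In actual training, the same quantity would be computed in an autodiff framework on a minibatch of encoder outputs and added to the reconstruction loss; this NumPy version only evaluates it.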
Abstract
Learning disentangled representations, where distinct factors of variation are captured by independent latent variables, is a central goal in machine learning. The dominant approach has been the Variational Autoencoder (VAE) framework, which uses a Kullback-Leibler (KL) divergence penalty to encourage the latent space to match a factorized Gaussian prior. In this work, however, we provide direct evidence that this KL-based regularizer is an unreliable mechanism: it consistently fails to enforce the target distribution on the aggregate posterior. We validate this and quantify the resulting entanglement using our novel, unsupervised Latent Predictability Score (LPS). To address this failure, we introduce the Programmable Prior Framework, a method built on the Maximum Mean Discrepancy (MMD). Our framework allows practitioners to explicitly sculpt the latent space, achieving state-of-the-art mutual independence on complex datasets like CIFAR-10 and Tiny ImageNet without the usual trade-off in reconstruction quality. Furthermore, we demonstrate how this programmability can be used to engineer sophisticated priors that improve alignment with semantically meaningful features. Ultimately, our work provides a foundational tool for representation engineering, opening new avenues for model identifiability and causal reasoning.
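One way to see why a per-sample KL penalty need not control the aggregate posterior is a standard identity from the VAE literature (shown here for illustration, not necessarily the argument the authors make). Writing $q_{\theta_1}(z) = \mathbb{E}_{p(x)}[q_{\theta_1}(z \mid x)]$ for the aggregate posterior and $I_q(x; z)$ for the mutual information between $x$ and $z$ under the joint $p(x)\,q_{\theta_1}(z \mid x)$, the averaged KL term splits into two parts:

$$
\mathbb{E}_{p(x)}\!\left[\mathrm{KL}\big(q_{\theta_1}(z \mid x)\,\|\,p(z)\big)\right]
= I_q(x; z) + \mathrm{KL}\big(q_{\theta_1}(z)\,\|\,p(z)\big).
$$

Only the second term measures the mismatch between the aggregate posterior and the prior; the first penalizes how much the code carries about the input, which an autoencoder must keep large to reconstruct well. An MMD penalty such as $\mathrm{MMD}^2\big(q_{\theta_1}(z), p(z)\big)$ instead acts on the aggregate posterior directly, without a mutual-information term.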
