ARD-VAE: A Statistical Formulation to Find the Relevant Latent Dimensions of Variational Autoencoders

Surojit Saha, Sarang Joshi, Ross Whitaker

TL;DR

The paper addresses the challenge of choosing latent dimensionality in variational autoencoders by introducing ARD-VAE, which uses a hierarchical prior $p(\boldsymbol{z} \,|\, D_{\boldsymbol{z}})$ learned from latent-space data $D_{\boldsymbol{z}}$ drawn from a subset $\boldsymbol{\mathcal{X}}_{\boldsymbol{\alpha}}$ of the training data. This automatic relevancy detection prunes unnecessary latent axes, enabling the model to focus on a compact set of active dimensions without manual hyperparameter tuning. Empirical results across MNIST, CelebA, CIFAR10, DSprites, 3D Shapes, and ImageNet demonstrate that ARD-VAE identifies the true factors of variation, achieves favorable FID and precision-recall metrics, and scales to large datasets while maintaining robustness to the initial latent space size. Overall, ARD-VAE offers a practical, data-driven approach to latent dimension selection that improves disentanglement and distribution modeling in VAEs, with clear benefits for real-world large-scale applications such as ImageNet-level data.

Abstract

The variational autoencoder (VAE) is a popular, deep, latent-variable model (DLVM) due to its simple yet effective formulation for modeling the data distribution. Moreover, optimizing the VAE objective function is more manageable than other DLVMs. The bottleneck dimension of the VAE is a crucial design choice, and it has strong ramifications for the model's performance, such as finding the hidden explanatory factors of a dataset using the representations learned by the VAE. However, the size of the latent dimension of the VAE is often treated as a hyperparameter estimated empirically through trial and error. To this end, we propose a statistical formulation to discover the relevant latent factors required for modeling a dataset. In this work, we use a hierarchical prior in the latent space that estimates the variance of the latent axes using the encoded data, which identifies the relevant latent dimensions. For this, we replace the fixed prior in the VAE objective function with a hierarchical prior, keeping the remainder of the formulation unchanged. We call the proposed method the automatic relevancy detection in the variational autoencoder (ARD-VAE). We demonstrate the efficacy of the ARD-VAE on multiple benchmark datasets in finding the relevant latent dimensions and their effect on different evaluation metrics, such as FID score and disentanglement analysis.
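
The abstract describes estimating the variance of each latent axis from the encoded data. The sketch below illustrates only that basic idea with a plain empirical variance per axis and the ratio between the largest and smallest variance. It is not the paper's hierarchical-prior estimator or its relevance score; the function names `per_axis_variance` and `variance_spread` and the toy latent vectors are assumptions made for illustration.

```python
from statistics import pvariance


def per_axis_variance(latents):
    """Return the empirical (population) variance of each latent axis.

    `latents` is a list of equal-length latent vectors, one per encoded sample.
    This is a stand-in for the variance estimate described in the abstract,
    not the estimator used by ARD-VAE.
    """
    if not latents:
        raise ValueError("need at least one latent vector")
    dim = len(latents[0])
    if any(len(z) != dim for z in latents):
        raise ValueError("all latent vectors must have the same length")
    # zip(*latents) transposes samples-by-axes into axes-by-samples.
    return [pvariance(axis_values) for axis_values in zip(*latents)]


def variance_spread(variances):
    """Ratio of the largest to the smallest per-axis variance."""
    smallest = min(variances)
    if smallest <= 0.0:
        raise ValueError("variances must be strictly positive")
    return max(variances) / smallest


# Hypothetical encoded latents: axis 0 varies widely, axis 1 barely moves.
encoded = [[-2.0, 0.01], [0.0, -0.01], [2.0, 0.01], [0.0, -0.01]]
variances = per_axis_variance(encoded)
spread = variance_spread(variances)
```

A large spread between axes is the kind of separation that a relevance criterion could exploit; how variances map to relevance is defined by the paper and is not reproduced here.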

Paper Structure (3 sections, 1 equation, 4 figures, 15 tables, 2 algorithms)

Figures (4)

  • Figure 1: The minimum and maximum variances estimated by the ARD-VAE while training on the MNIST, CelebA and CIFAR10 datasets for multiple latent dimensions. The maximum estimated variances are orders of magnitude higher than the minimum estimated variances.
  • Figure 2: Latent traversal of the DSprites dataset [dsprites17] in the range $[-3\sigma, 3\sigma]$ using the relevant axes discovered by the ARD-VAE. The latent factors are listed in the left column. All latent factors are represented by independent latent axes, with slight entanglement of Shape and Position Y. The MIG score for this model is $\mathbf{0.35}$.
  • Figure 3: Latent traversal of the 3D Shapes dataset [3dshapes18] in the range $[-3\sigma, 3\sigma]$ using the relevant axes discovered by the ARD-VAE. The latent factors are listed in the left column. All latent factors are represented by independent latent axes with almost no overlap between them. The MIG score for this model is $\mathbf{0.84}$.
  • Figure 4: Latent traversal of the 3D Shapes dataset [3dshapes18] in the range $[-3\sigma, 3\sigma]$ on all $L=10$ latent axes used to train the ARD-VAE, sorted by the relevance score proposed in the paper. The latent factors are listed in the left column for the relevant axes (total $6$). The additional axes (total $4$), highlighted by the red bounding box, show no variability in the output in response to deviations along them. This illustrates our hypothesis about the behavior of irrelevant latent axes.
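
The latent traversals in Figures 2 to 4 sweep one latent axis over $[-3\sigma, 3\sigma]$ while the others stay fixed. A minimal sketch of how such a grid of latent vectors could be built is shown below, assuming $\sigma$ is the standard deviation of the swept axis. The helper names `traversal_values` and `traverse_axis`, the number of steps, and the base vector are hypothetical; decoding each vector into an image is left to the trained model and is not shown.

```python
def traversal_values(sigma, steps=7, span=3.0):
    """Evenly spaced values from -span*sigma to +span*sigma, inclusive."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    lo, hi = -span * sigma, span * sigma
    return [lo + (hi - lo) * i / (steps - 1) for i in range(steps)]


def traverse_axis(base_latent, axis, sigma, steps=7):
    """Copies of `base_latent` in which only `axis` is swept over [-3*sigma, 3*sigma]."""
    rows = []
    for value in traversal_values(sigma, steps):
        z = list(base_latent)  # copy so the base vector is not modified
        z[axis] = value
        rows.append(z)
    return rows


# Hypothetical 3-dimensional latent space; sweep axis 1 with sigma = 2.0.
base = [0.0, 0.0, 0.0]
grid = traverse_axis(base, axis=1, sigma=2.0, steps=7)
```

Each row of `grid` would be passed to the decoder to produce one column of a traversal figure. For an irrelevant axis, the captions above suggest the decoded outputs would show no visible change across the row.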