
The VampPrior Mixture Model

Andrew A. Stirn, David A. Knowles

TL;DR

The VampPrior Mixture Model (VMM) replaces the standard prior $p(z)=\mathcal{N}(0,I)$ in deep latent variable models with a Bayesian Gaussian mixture prior. This enables automatic clustering in latent space, similar to a Dirichlet process Gaussian mixture model (DP-GMM). Training alternates two procedures: variational inference learns the variational parameters, and empirical Bayes (MAP-EM) learns the prior hyperparameters. Cluster centers are modeled as distributions, $\mu_j \sim q_\phi(\mu_j;u_j)$, with widths $\Lambda_j$. Empirically, the VMM achieves strong clustering on image benchmarks. Embedded in scVI, it substantially improves scRNA-seq integration, outperforming both the standard Gaussian prior and VampPrior-based approaches on key metrics. The work shows that a flexible DP-GMM prior with distributional cluster centers improves both interpretability and performance. The VMM's complexity lies between that of a plain VAE and the VampPrior, and its cluster granularity can be tuned for different data modalities.
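The sketch below illustrates the kind of MAP-EM step the TL;DR mentions. It is a generic example for a finite GMM over fixed latent codes: an E-step computes responsibilities, and an M-step updates the mixing weights $\pi$ under a symmetric Dirichlet($\alpha$) prior. It is not the paper's algorithm, and it rests on several assumptions:

- Means and variances are held fixed.
- `alpha` and all function names are hypothetical.
- There is no alternation with variational inference.

For $\alpha \ge 1$, the update $\pi_j \propto N_j + \alpha - 1$ is the closed-form MAP estimate. Clamping at zero for $\alpha < 1$ is a common heuristic that prunes weakly supported components, loosely mirroring how a DP-GMM-like prior can leave clusters unused.

```python
import math

def log_normal_diag(z, mean, var):
    """Log density of N(z; mean, diag(var))."""
    return sum(-0.5 * (math.log(2 * math.pi * v) + (zi - m) ** 2 / v)
               for zi, m, v in zip(z, mean, var))

def e_step(zs, weights, means, variances):
    """Responsibilities r[i][j] = p(component j | z_i) under the current mixture."""
    resp = []
    for z in zs:
        # Components with zero weight get log-probability -inf (responsibility 0).
        logs = [math.log(w) + log_normal_diag(z, m, v) if w > 0 else -math.inf
                for w, m, v in zip(weights, means, variances)]
        top = max(logs)
        unnorm = [math.exp(lp - top) for lp in logs]  # math.exp(-inf) == 0.0
        total = sum(unnorm)
        resp.append([u / total for u in unnorm])
    return resp

def map_m_step_weights(resp, alpha):
    """MAP mixing weights under a symmetric Dirichlet(alpha) prior (hypothetical helper).

    pi_j is proportional to max(N_j + alpha - 1, 0), where N_j = sum_i r[i][j].
    Exact MAP for alpha >= 1; the clamp at 0 is a pruning heuristic for alpha < 1.
    """
    k = len(resp[0])
    counts = [sum(r[j] for r in resp) for j in range(k)]
    unnorm = [max(n + alpha - 1.0, 0.0) for n in counts]
    total = sum(unnorm)
    return [u / total for u in unnorm]

# Usage: two well-separated 1-D clusters, three initial components.
zs = [[-2.1], [-1.9], [-2.0], [2.0], [1.9], [2.1]]
means, variances = [[-2.0], [0.0], [2.0]], [[1.0]] * 3
weights = [1 / 3] * 3
for _ in range(5):
    weights = map_m_step_weights(e_step(zs, weights, means, variances), alpha=0.5)
print(weights)  # the middle component's weight is pruned to 0.0
```

With $\alpha = 0.5$, the unused middle component loses its weight after two iterations and stays at zero. The two supported components split the mass evenly.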

Abstract

Widely used deep latent variable models (DLVMs), in particular Variational Autoencoders (VAEs), employ overly simplistic priors on the latent space. To achieve strong clustering performance, existing methods that replace the standard normal prior with a Gaussian mixture model (GMM) must set the number of clusters a priori to roughly the expected number of ground-truth classes, and they are susceptible to poor initializations. We leverage VampPrior concepts (Tomczak and Welling, 2018) to fit a Bayesian GMM prior, resulting in the VampPrior Mixture Model (VMM), a novel prior for DLVMs. In a VAE, the VMM attains highly competitive clustering performance on benchmark datasets. Integrating the VMM into scVI (Lopez et al., 2018), a popular scRNA-seq integration method, significantly improves its performance and automatically arranges cells into clusters with similar biological characteristics.
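The abstract contrasts the standard normal prior with a GMM prior. The minimal sketch below compares the log prior density of a latent code under each. It is not the authors' implementation: it assumes diagonal covariances and point-estimate cluster centers, whereas the VMM treats centers as distributions, $\mu_j \sim q_\phi(\mu_j;u_j)$. All function names are hypothetical.

```python
import math

def log_normal_diag(z, mean, var):
    """Log density of N(z; mean, diag(var)) for a single latent vector."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (zi - m) ** 2 / v)
        for zi, m, v in zip(z, mean, var)
    )

def log_standard_normal_prior(z):
    """Standard VAE prior: log N(z; 0, I)."""
    return log_normal_diag(z, [0.0] * len(z), [1.0] * len(z))

def log_gmm_prior(z, weights, means, variances):
    """log sum_j pi_j N(z; mu_j, diag(var_j)); weights must be positive and sum to 1."""
    terms = [math.log(w) + log_normal_diag(z, m, v)
             for w, m, v in zip(weights, means, variances)]
    top = max(terms)  # log-sum-exp trick for numerical stability
    return top + math.log(sum(math.exp(t - top) for t in terms))

# Usage: a latent code at one cluster center is far more probable under the
# two-component mixture than under N(0, I).
z = [3.0, 3.0]
weights = [0.5, 0.5]
means = [[-3.0, -3.0], [3.0, 3.0]]
variances = [[1.0, 1.0], [1.0, 1.0]]
gain = log_gmm_prior(z, weights, means, variances) - log_standard_normal_prior(z)
```

Here `gain` is about $9 - \log 2 \approx 8.31$ nats. Under $\mathcal{N}(0,I)$ the code sits in a low-density tail, whereas the mixture places a component directly on it. A one-component mixture with zero mean and unit variances reduces exactly to the standard normal prior.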

Paper Structure

This paper contains 19 sections, 16 equations, 7 figures, 4 tables, and 1 algorithm.

Figures (7)

  • Figure 1: VMM prior predictive samples for MNIST. The number of columns equals the number of utilized clusters. The bottom row shows the data point with the highest probability of belonging to each cluster, with the cluster's probability $\pi_j$ printed beneath it. The rows above are samples from the corresponding cluster. We sample $\text{round}(10 \cdot \pi_j / \max(\pi))$ images from each component $j$ to help visualize cluster proportions. A column without samples indicates that this value rounded to zero.
  • Figure 2: MDE comparison for the lung atlas dataset. Each row shows the embedding for one tested prior, repeated across columns. The columns label points by technical batch identifier, annotated cell type, and the prior's cluster assignment, respectively.
  • Figure 3: VMM cluster utilization and NMI performance for different batch sizes. Shading denotes 95% CIs. The dotted green rectangle marks the peak NMI performance between class labels and cluster assignments.
  • Figure 4: Prior predictive samples. The number of columns equals the number of utilized clusters. The bottom row shows the data point with the highest probability of belonging to each cluster, with the cluster's probability $\pi_j$ printed beneath it. The rows above are samples from the corresponding cluster. We sample $\text{round}(10 \cdot \pi_j / \max(\pi))$ images from each component $j$ to help visualize cluster proportions. A column without samples indicates that this value rounded to zero. The VampPrior has uniform prior class probabilities because it fixes $\pi_j=K^{-1}$ and does not fit $\pi$ during inference.
  • Figure 5: MDE comparison for the cortex dataset. Each row shows the embedding for one tested prior, repeated across columns. The columns label points by technical batch identifier, annotated cell type, and the prior's cluster assignment, respectively.
  • ...and 2 more figures
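Figures 1 and 4 set the number of prior samples drawn per cluster to $\text{round}(10 \cdot \pi_j / \max(\pi))$. The small sketch below shows that arithmetic. The function name is hypothetical, and rounding halves upward is an assumption: the captions do not specify tie-breaking, and Python's built-in `round` would instead round halves to even.

```python
import math

def samples_per_cluster(pi, scale=10):
    """Samples drawn from component j: round(scale * pi_j / max(pi)) (hypothetical helper)."""
    top = max(pi)
    # Round half up; the captions do not say how ties are broken (assumption).
    return [math.floor(scale * p / top + 0.5) for p in pi]

# Usage: the most probable cluster always gets `scale` samples; very small
# clusters round to zero and their columns show only the bottom-row data point.
counts = samples_per_cluster([0.5, 0.3, 0.15, 0.049, 0.001])
print(counts)  # [10, 6, 3, 1, 0]
```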