Bridging the inference gap in Multimodal Variational Autoencoders

Agathe Senellart, Stéphanie Allassonnière

TL;DR

This work tackles the inference gap and generation quality in multimodal variational autoencoders by proposing two non-aggregation-based approaches that separately learn a joint generative model and refined unimodal posteriors. The core ideas include using Normalizing Flows to flexibly approximate subset posteriors, sampling subset posteriors with a Product-of-Experts formulation and Hamiltonian Monte Carlo, and optionally leveraging shared information across modalities through pretrained projectors g_j (JNF-Shared). The methods achieve state-of-the-art coherence and competitive diversity on benchmarks such as MNIST-SVHN, PolyMNIST, Translated PolyMNIST, and MHD, while remaining scalable to many modalities. The work also discusses extensions with contrastive or DCCA-based shared representations and highlights practical implications for robust multimodal generation in real-world applications.
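The Product-of-Experts (PoE) plus Hamiltonian Monte Carlo (HMC) idea can be illustrated with a small NumPy sketch. It assumes the common PoE subset-posterior form $q(z|x_S) \propto \prod_{j \in S} q_{\phi_j}(z|x_j) / p(z)^{|S|-1}$ and diagonal Gaussian experts. Both are illustrative assumptions: the paper's exact expression may differ, and its unimodal posteriors are flow-based, so no closed form exists and a sampler such as HMC is needed. With Gaussian experts the product has a closed form, which lets us check a plain HMC sampler against it. All function names are hypothetical.

```python
import numpy as np


def log_normal_diag(z, mu, var):
    """Log-density of N(mu, diag(var)) evaluated at z (last axis = latent dim)."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (z - mu) ** 2 / var, axis=-1)


def poe_gaussian(mus, variances, prior_mu, prior_var):
    """Closed form of prod_j N(mu_j, var_j) / N(prior_mu, prior_var)^(k-1)."""
    k = len(mus)
    precision = sum(1.0 / v for v in variances) - (k - 1) / prior_var
    if np.any(precision <= 0):
        raise ValueError("PoE precision must be positive")
    weighted = sum(m / v for m, v in zip(mus, variances)) - (k - 1) * prior_mu / prior_var
    return weighted / precision, 1.0 / precision


def make_poe_target(mus, variances, prior_mu, prior_var):
    """Unnormalized log-density of the PoE subset posterior and its gradient."""
    k = len(mus)

    def log_p(z):
        experts = sum(log_normal_diag(z, m, v) for m, v in zip(mus, variances))
        return experts - (k - 1) * log_normal_diag(z, prior_mu, prior_var)

    def grad_log_p(z):
        experts = sum(-(z - m) / v for m, v in zip(mus, variances))
        return experts + (k - 1) * (z - prior_mu) / prior_var

    return log_p, grad_log_p


def hmc(log_p, grad_log_p, z0, n_samples, step=0.1, n_leapfrog=20, seed=0):
    """Hamiltonian Monte Carlo with identity mass matrix and leapfrog integration."""
    rng = np.random.default_rng(seed)
    z = np.asarray(z0, dtype=float)
    samples = np.empty((n_samples, z.size))
    for i in range(n_samples):
        p0 = rng.standard_normal(z.shape)
        z_new, p = z.copy(), p0 + 0.5 * step * grad_log_p(z)  # half momentum step
        for j in range(n_leapfrog):
            z_new = z_new + step * p  # full position step
            if j < n_leapfrog - 1:
                p = p + step * grad_log_p(z_new)  # full momentum step
        p = p + 0.5 * step * grad_log_p(z_new)  # final half momentum step
        # Metropolis correction on the Hamiltonian H = -log p(z) + |p|^2 / 2
        h_old = -log_p(z) + 0.5 * np.sum(p0 ** 2)
        h_new = -log_p(z_new) + 0.5 * np.sum(p ** 2)
        if np.log(rng.uniform()) < h_old - h_new:
            z = z_new
        samples[i] = z
    return samples


# Two hypothetical unimodal experts in a 2-D latent space, standard normal prior.
mus = [np.array([1.0, -1.0]), np.array([0.5, 0.5])]
variances = [np.array([0.5, 0.5]), np.array([0.25, 0.25])]
prior_mu, prior_var = np.zeros(2), np.ones(2)

mean, var = poe_gaussian(mus, variances, prior_mu, prior_var)  # mean ≈ [0.8, 0.0], var ≈ [0.2, 0.2]
log_p, grad_log_p = make_poe_target(mus, variances, prior_mu, prior_var)
samples = hmc(log_p, grad_log_p, z0=np.zeros(2), n_samples=3000)[500:]  # drop burn-in
```

The HMC sampler only needs the unnormalized log-density and its gradient. This is why it also applies when the experts are Normalizing Flows, where the closed-form `poe_gaussian` shortcut is unavailable.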

Abstract

From medical diagnosis to autonomous vehicles, critical applications rely on the integration of multiple heterogeneous data modalities. Multimodal Variational Autoencoders offer versatile and scalable methods for generating unobserved modalities from observed ones. Recent models using mixtures-of-experts aggregation suffer from theoretically grounded limitations that restrict their generation quality on complex datasets. In this article, we propose a novel interpretable model able to learn both joint and conditional distributions without introducing mixture aggregation. Our model follows a multistage training process: first modeling the joint distribution with variational inference and then modeling the conditional distributions with Normalizing Flows to better approximate true posteriors. Importantly, we also propose to extract and leverage the information shared between modalities to improve the conditional coherence of generated samples. Our method achieves state-of-the-art results on several benchmark datasets.
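To show why a Normalizing Flow gives a more flexible posterior than a diagonal Gaussian while keeping an exact density, here is a minimal NumPy sketch of the change-of-variables computation. It uses RealNVP-style affine coupling layers with random, untrained weights. The layer type, sizes, and class names are assumptions for illustration, not the paper's architecture. In the paper's setting the flow would also be conditioned on the observed modality $x_j$ (for example through an encoder), which this sketch omits.

```python
import numpy as np


class AffineCoupling:
    """RealNVP-style affine coupling: one half of z conditions a scale/shift of the other half."""

    def __init__(self, dim, flip, rng, hidden=16):
        assert dim % 2 == 0, "this sketch assumes an even latent dimension"
        self.half, self.flip = dim // 2, flip
        # Hypothetical random (untrained) conditioner weights; a real flow would learn these.
        self.w1 = rng.normal(scale=0.5, size=(self.half, hidden))
        self.w2 = rng.normal(scale=0.5, size=(hidden, 2 * self.half))

    def _split(self, v):
        a, b = v[..., :self.half], v[..., self.half:]
        return (b, a) if self.flip else (a, b)  # (conditioning half, transformed half)

    def _merge(self, cond, trans):
        parts = [trans, cond] if self.flip else [cond, trans]
        return np.concatenate(parts, axis=-1)

    def _scale_shift(self, cond):
        out = np.tanh(cond @ self.w1) @ self.w2
        log_s, t = np.split(out, 2, axis=-1)
        return np.tanh(log_s), t  # bounded log-scale keeps the map well conditioned

    def forward(self, z):
        """Base -> data direction; returns the output and log|det dx/dz|."""
        cond, trans = self._split(z)
        log_s, t = self._scale_shift(cond)
        return self._merge(cond, trans * np.exp(log_s) + t), np.sum(log_s, axis=-1)

    def inverse(self, x):
        """Data -> base direction; returns the output and log|det dz/dx|."""
        cond, trans = self._split(x)
        log_s, t = self._scale_shift(cond)
        return self._merge(cond, (trans - t) * np.exp(-log_s)), -np.sum(log_s, axis=-1)


class Flow:
    """Stack of coupling layers on top of a standard normal base distribution."""

    def __init__(self, dim=2, n_layers=4, seed=0):
        rng = np.random.default_rng(seed)
        self.dim = dim
        self.layers = [AffineCoupling(dim, flip=(i % 2 == 1), rng=rng) for i in range(n_layers)]

    def forward(self, z):
        log_det = 0.0
        for layer in self.layers:
            z, ld = layer.forward(z)
            log_det = log_det + ld
        return z, log_det

    def inverse(self, x):
        log_det = 0.0
        for layer in reversed(self.layers):
            x, ld = layer.inverse(x)
            log_det = log_det + ld
        return x, log_det

    def log_prob(self, x):
        """Change of variables: log q(x) = log N(f^{-1}(x); 0, I) + log|det d f^{-1}/dx|."""
        z, log_det = self.inverse(x)
        return -0.5 * np.sum(z ** 2 + np.log(2 * np.pi), axis=-1) + log_det

    def sample(self, n, rng):
        return self.forward(rng.standard_normal((n, self.dim)))[0]


flow = Flow(dim=2, n_layers=4, seed=0)
x = flow.sample(5, np.random.default_rng(1))
print(flow.log_prob(x))  # exact log-densities of the flow samples
```

Coupling layers are chosen here because their Jacobian is triangular, so the log-determinant is just the sum of the log-scales, and the inverse is available in closed form.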

Paper Structure

This paper contains 46 sections, 30 equations, 17 figures, 6 tables, 1 algorithm.

Figures (17)

  • Figure 1: Graphical models in the case $M=2$. Dashed lines represent decoders, solid lines represent encoders, and red arrows represent the projectors extracting shared information. "NF" refers to Normalizing Flows.
  • Figure 2: a) Samples from the toy dataset. b) The joint generative model $p_{\theta}(x_1,x_2)$ has been learned and we visualize the 2-dimensional latent space. Each point encodes a pair of images $(x_1,x_2)$, and its color indicates the size and class of the encoded square. We try to approximate the posterior $p_{\theta}(z|x_1)$ of a large square image $x_1$ (shown in the top left), which corresponds to the dark blue dots in the latent space. In b), we use a diagonal Gaussian distribution and in c) we use Normalizing Flows. Normalizing Flows capture a realistic posterior, whereas the Gaussian distribution has too large a support, leading to unrealistic generations (framed in red). d) Using DCCA, we extract the information shared across modalities, which is the shape class: full or empty. We learn $q_{\phi_1}(z|g_1(x_1))$ and see that it closely approximates the region of the latent space that encodes full shapes. For b), c), and d), samples generated in the circle modality using the learned posterior are shown on the right side of each plot. Both c) and d) produce relevant and diverse samples.
  • Figure 3: First row: generation from MNIST to SVHN. Second row: generation from SVHN to MNIST. Third row: generation from SVHN to SVHN (unimodal reconstruction). Samples framed in red have a well-reconstructed background but a poorly reconstructed digit. JNF-CL refers to our model JNF-Shared with CL. Note that for this model, when reconstructing SVHN, we sample $z \sim q_{\phi_{2}}(z|g_{2}(x_{2}))$; the background information is therefore filtered out by the projector $g_{2}$ and cannot be reconstructed. However, the digit, which is what cross-modal generation requires, is well preserved.
  • Figure 4: The two left columns present conditional generation results when varying the number of conditioning modalities. The right column displays coherence and FID for unconditional generation. Each point corresponds to a different training seed. In these plots, the best models (high coherence, low FID) lie in the top left corner. The FID is computed on 10,000 samples of the first modality.
  • Figure 5: Joint generation in all five modalities when sampling a latent code from the prior. In each image, each row corresponds to a modality. JNF-CL (resp. DCCA) correspond to our method JNF-Shared with CL (resp. DCCA).
  • ...and 12 more figures
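Figure 2 d) extracts the information shared across modalities with DCCA, which trains neural projectors $g_j$ under the canonical correlation analysis (CCA) objective. As a simplified stand-in (linear maps instead of neural networks, which is our substitution and not the paper's setup), the sketch below computes classical linear CCA with NumPy on a hypothetical synthetic pair of modalities that share one latent factor and each have two private factors. The first canonical direction plays the role of $g_j(x_j)$: it keeps the shared factor and discards the private ones, much as $g_2$ filters out the SVHN background in Figure 3. Function names and the data generator are hypothetical.

```python
import numpy as np


def _inv_sqrt(cov):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    w, v = np.linalg.eigh(cov)
    return (v / np.sqrt(w)) @ v.T


def linear_cca(x1, x2, k, reg=1e-4):
    """Return linear projectors for the k most correlated directions and their correlations."""
    mu1, mu2 = x1.mean(axis=0), x2.mean(axis=0)
    c1, c2 = x1 - mu1, x2 - mu2
    n = x1.shape[0]
    s11 = c1.T @ c1 / (n - 1) + reg * np.eye(x1.shape[1])  # regularized covariances
    s22 = c2.T @ c2 / (n - 1) + reg * np.eye(x2.shape[1])
    s12 = c1.T @ c2 / (n - 1)
    a, b = _inv_sqrt(s11), _inv_sqrt(s22)
    u, corrs, vt = np.linalg.svd(a @ s12 @ b)  # singular values = canonical correlations
    w1, w2 = a @ u[:, :k], b @ vt.T[:, :k]

    def g1(x):  # hypothetical linear "projector" for modality 1
        return (x - mu1) @ w1

    def g2(x):  # hypothetical linear "projector" for modality 2
        return (x - mu2) @ w2

    return g1, g2, corrs[:k]


# Synthetic data: one shared factor plus two private factors per modality, small noise.
rng = np.random.default_rng(0)
n = 2000
shared = rng.standard_normal((n, 1))
x1 = np.hstack([shared, rng.standard_normal((n, 2))]) @ rng.standard_normal((3, 5))
x2 = np.hstack([shared, rng.standard_normal((n, 2))]) @ rng.standard_normal((3, 4))
x1 += 0.1 * rng.standard_normal(x1.shape)
x2 += 0.1 * rng.standard_normal(x2.shape)

g1, g2, corrs = linear_cca(x1, x2, k=2)
print(corrs)  # first correlation near 1 (shared factor), second near 0 (private factors)
```

DCCA replaces the linear maps `w1`, `w2` with networks trained to maximize the same correlation. The contrastive (CL) variant mentioned in the captions is a different way of obtaining such shared representations.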