
Improving the Reconstruction of Disentangled Representation Learners via Multi-Stage Modeling

Akash Srivastava, Yamini Bansal, Yukun Ding, Cole Lincoln Hurwitz, Kai Xu, Bernhard Egger, Prasanna Sattigeri, Joshua B. Tenenbaum, Agus Sudjianto, Phuong Le, Arun Prakash R, Nengfeng Zhou, Joel Vaughan, Yaqun Wang, Anwesha Bhattacharyya, Kristjan Greenewald, David D. Cox, Dan Gutfreund

TL;DR

A novel multi-stage modeling approach in which the disentangled factors are first learned using a pre-existing disentangled representation learning method (such as $\beta$-TCVAE); the low-quality reconstruction is then improved by a second deep generative model that is trained to model the missing correlated latent variables.
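$\beta$-TCVAE, the example first-stage method named above, penalizes the total correlation (TC) of the aggregate posterior, $\mathrm{TC}(z) = \mathrm{KL}\big(q(z) \,\|\, \prod_j q(z_j)\big)$, which is zero exactly when the latent dimensions are statistically independent. As an illustration only, the sketch below evaluates TC in closed form for a Gaussian $q(z)$. $\beta$-TCVAE itself estimates this term from minibatch samples, so `gaussian_total_correlation` is a hypothetical helper for intuition and not part of that method.

```python
import numpy as np

def gaussian_total_correlation(cov):
    """Total correlation KL(q(z) || prod_j q(z_j)) of a Gaussian q(z) with covariance `cov`.

    Closed form: 0.5 * (sum_j log cov[j, j] - log det cov); it is zero iff `cov` is diagonal.
    """
    cov = np.asarray(cov, dtype=float)
    chol = np.linalg.cholesky(cov)  # raises LinAlgError if cov is not positive definite
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))
    return 0.5 * (np.sum(np.log(np.diag(cov))) - logdet)

independent = np.diag([1.0, 2.0])                # uncorrelated latent dimensions
correlated = np.array([[1.0, 0.8], [0.8, 1.0]])  # correlation 0.8 between the two dimensions
print(gaussian_total_correlation(independent))   # ~0.0
print(gaussian_total_correlation(correlated))    # -0.5 * log(0.36), about 0.511
```

Pushing TC toward zero removes exactly the kind of correlated structure that the second stage is later asked to model.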

Abstract

Current autoencoder-based disentangled representation learning methods achieve disentanglement by penalizing the (aggregate) posterior to encourage statistical independence of the latent factors. This approach introduces a trade-off between disentangled representation learning and reconstruction quality, since the model does not have enough capacity to learn the correlated latent variables that capture the fine detail present in most image data. To overcome this trade-off, we present a novel multi-stage modeling approach in which the disentangled factors are first learned using a penalty-based disentangled representation learning method; the low-quality reconstruction is then improved by another deep generative model that is trained to model the missing correlated latent variables, adding detail while remaining conditioned on the previously learned disentangled factors. Taken together, our multi-stage modeling approach yields a single, coherent probabilistic model that is theoretically justified by the principle of d-separation and can be realized with a variety of model classes, including likelihood-based models such as variational autoencoders, implicit models such as generative adversarial networks, and tractable models such as normalizing flows or mixtures of Gaussians. We demonstrate that our multi-stage model achieves higher reconstruction quality than current state-of-the-art methods with equivalent disentanglement performance across multiple standard benchmarks. In addition, we apply the multi-stage model to generate synthetic tabular datasets, showing improved performance over benchmark models across a variety of metrics. An interpretability analysis further indicates that the multi-stage model can uncover distinct and meaningful features of variation from which the original distribution can be recovered.
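The two-stage structure described above can be pictured as ancestral sampling along the graph $C \to Y \to X \leftarrow Z$, where $C$ are the disentangled factors, $Y$ the coarse first-stage reconstruction, $Z$ the correlated detail latents, and $X$ the final output. The sketch below is a toy illustration under stated assumptions: `stage1_decode` and `stage2_refine` are hypothetical fixed random maps standing in for the trained $\beta$-TCVAE decoder and the second-stage generative model (which the paper says may be a VAE, GAN, or flow), and the dimensions are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dimensions, not taken from the paper.
DIM_C, DIM_Z, DIM_X = 3, 5, 16

# Stand-in for the first-stage decoder: disentangled factors C -> coarse output Y.
A = rng.normal(size=(DIM_X, DIM_C))
# Stand-in for the second stage: correlated detail latents Z -> residual detail added to Y.
B = 0.1 * rng.normal(size=(DIM_X, DIM_Z))

def stage1_decode(c):
    """Coarse (blurry) reconstruction Y computed from the disentangled factors C only."""
    return np.tanh(c @ A.T)

def stage2_refine(y, z):
    """Detailed output X conditioned on the coarse output Y and the extra latents Z."""
    return y + np.tanh(z @ B.T)

def sample_multistage(n, c=None):
    """Ancestral sampling along C -> Y -> X <- Z; C influences X only through Y."""
    if c is None:
        c = rng.standard_normal((n, DIM_C))
    z = rng.standard_normal((c.shape[0], DIM_Z))
    y = stage1_decode(c)
    x = stage2_refine(y, z)
    return c, z, y, x

# Hold C fixed and resample Z: Y is unchanged, only the added detail in X varies.
c_fixed = rng.standard_normal((1, DIM_C))
_, _, y1, x1 = sample_multistage(1, c=c_fixed)
_, _, y2, x2 = sample_multistage(1, c=c_fixed)
print(np.allclose(y1, y2), np.abs(x1 - x2).max())
```

Because $C$ reaches $X$ only through $Y$, the detail latents can be resampled without altering what the disentangled factors encode, which is the separation of roles the multi-stage model relies on.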

Paper Structure

This paper contains 60 sections, 5 equations, 49 figures, 7 tables.

Figures (49)

  • Figure 1: Image reconstruction using $\beta$-TCVAE and MS-VAE, compared with the target image. MS-VAE takes the blurry output of the underlying $\beta$-TCVAE model and learns to render a much better approximation of the target while preserving the pose of the original image.
  • Figure 2: (a) Graphical model of a standard VAE where $C$ and $Z$ are not independent conditioned on $X$. (b) Graphical model of $\beta$-TCVAE where the reconstruction only depends on the independent latent factors $C$. (c) MS-VAE graphical model where $C$ and $Z$ are independent conditioned on $Y$. (d) Schematic for MS-VAE when implemented as a convolutional architecture. Both $Y$ and $X$ are the reconstructions of the same image.
  • Figure 3: FID (lower is better) and MIG (higher is better) comparison of the $\beta$-TCVAE, $\beta$-TCVAE-L, and MS-VAE models. On both datasets, MS-VAE consistently improves the reconstruction quality of its underlying $\beta$-TCVAE model while achieving a better MIG than $\beta$-TCVAE-L. We also report FID and MIG for Lezama's model (lezama2018overcoming) with $\beta \in \{1, 10\}$, as well as FID for a vanilla VAE with the same capacity as MS-VAE (denoted Big-VAE).
  • Figure 4: Mutual information (MI) between the independent factors inferred (using $\beta$-TCVAE) from the true image X and the independent factors inferred from various reconstructions of X. Colors: blue $= M_1$, red $= M_2$, green $= M_3$, purple $= M_4$ (see the Experiments section for their definitions).
  • Figure 5: Object property prediction on SmallNORB using the inferred representations $C$, $Z$, and $C+Z$ as input to an MLP. Accuracy is highest for the MLP trained on $C+Z$.
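The MIG scores in Figure 3 measure, for each ground-truth factor $v_k$, the gap in mutual information between the two latent dimensions most informative about it, normalized by the factor's entropy: $\mathrm{MIG} = \frac{1}{K}\sum_k \frac{1}{H(v_k)}\big(I(z_{j^{(1)}}; v_k) - I(z_{j^{(2)}}; v_k)\big)$. The sketch below is a minimal histogram-based estimator, assuming discrete ground-truth factors and equal-width binning of each latent (`n_bins=20` is an assumed default); it is not necessarily the estimator used to produce the paper's numbers.

```python
import numpy as np

def discrete_mutual_info(a, b):
    """Mutual information (nats) between two discrete 1-D arrays, from their joint histogram."""
    a_vals, a_idx = np.unique(a, return_inverse=True)
    b_vals, b_idx = np.unique(b, return_inverse=True)
    joint = np.zeros((a_vals.size, b_vals.size))
    np.add.at(joint, (a_idx, b_idx), 1.0)
    joint /= joint.sum()
    pa = joint.sum(axis=1, keepdims=True)   # marginal of a, shape (na, 1)
    pb = joint.sum(axis=0, keepdims=True)   # marginal of b, shape (1, nb)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / (pa @ pb)[nz])))

def discrete_entropy(v):
    """Entropy (nats) of a discrete 1-D array."""
    _, counts = np.unique(v, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log(p)))

def mig(latents, factors, n_bins=20):
    """Mutual Information Gap for latents (N, J), J >= 2, and discrete factors (N, K)."""
    # Discretize each latent dimension into n_bins equal-width bins.
    binned = np.stack(
        [np.digitize(col, np.histogram_bin_edges(col, bins=n_bins)[1:-1]) for col in latents.T],
        axis=1,
    )
    gaps = []
    for k in range(factors.shape[1]):
        v = factors[:, k]
        mi = np.array([discrete_mutual_info(binned[:, j], v) for j in range(binned.shape[1])])
        top1, top2 = np.sort(mi)[::-1][:2]   # two most informative latents for this factor
        gaps.append((top1 - top2) / discrete_entropy(v))
    return float(np.mean(gaps))

rng = np.random.default_rng(0)
v = rng.integers(0, 5, size=(10_000, 2))          # two toy ground-truth factors
noise = rng.standard_normal((10_000, 1))
aligned = np.hstack([v.astype(float), noise])     # one latent per factor, plus a nuisance dimension
print(mig(aligned, v))                            # close to 1
```

Because the score uses the gap between the top two latents, a factor spread redundantly over several latent dimensions scores low even when each of those dimensions is individually informative.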