
Spatial Broadcast Decoder: A Simple Architecture for Learning Disentangled Representations in VAEs

Nicholas Watters, Loic Matthey, Christopher P. Burgess, Alexander Lerchner

TL;DR

The paper addresses disentangled representation learning in VAEs by proposing the Spatial Broadcast decoder, which tiles a latent vector across a spatial grid and appends fixed coordinate channels before a shallow unstrided convolutional decoder. This architectural prior enables the model to separate positional from non-positional features without supervision, improving disentanglement, reconstruction, and generalization, especially for small objects. Extensive experiments on colored sprites, Chairs, and 3D Object-in-Room show consistent gains in MIG and qualitative disentanglement, along with a simple latent-space visualization technique that clarifies latent geometry. The method is complementary to state-of-the-art disentangling approaches like FactorVAE and $\beta$-VAE and can be integrated to boost their performance with minimal hyperparameter tuning.

Abstract

We present a simple neural rendering architecture that helps variational autoencoders (VAEs) learn disentangled representations. Instead of the deconvolutional network typically used in the decoder of VAEs, we tile (broadcast) the latent vector across space, concatenate fixed X- and Y-"coordinate" channels, and apply a fully convolutional network with 1x1 stride. This provides an architectural prior for dissociating positional from non-positional features in the latent distribution of VAEs, yet without providing any explicit supervision to this effect. We show that this architecture, which we term the Spatial Broadcast decoder, improves disentangling, reconstruction accuracy, and generalization to held-out regions in data space. It provides a particularly dramatic benefit when applied to datasets with small objects. We also emphasize a method for visualizing learned latent spaces that helped us diagnose our models and may prove useful for others aiming to assess data representations. Finally, we show the Spatial Broadcast Decoder is complementary to state-of-the-art (SOTA) disentangling techniques and when incorporated improves their performance.
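
The broadcast step described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code: the function name `spatial_broadcast`, the channels-last layout, and the choice of coordinate values linearly spaced in [-1, 1] are assumptions made here for concreteness.

```python
import numpy as np


def spatial_broadcast(z, h, w):
    """Tile latents z of shape (batch, k) into a (batch, h, w, k + 2) grid.

    Hypothetical helper: layout and coordinate range are assumptions.
    """
    batch, k = z.shape
    # Tile (broadcast) every latent vector across all h x w positions.
    z_tiled = np.broadcast_to(z[:, None, None, :], (batch, h, w, k))
    # Fixed X and Y coordinate channels, assumed to span [-1, 1].
    x = np.linspace(-1.0, 1.0, w)
    y = np.linspace(-1.0, 1.0, h)
    x_grid, y_grid = np.meshgrid(x, y)            # each has shape (h, w)
    coords = np.stack([x_grid, y_grid], axis=-1)  # shape (h, w, 2)
    coords = np.broadcast_to(coords, (batch, h, w, 2))
    # Concatenate the coordinate channels onto the tiled latents.
    return np.concatenate([z_tiled, coords], axis=-1)
```

The result would then be passed to the unstrided convolutional network. Because every spatial position receives the same latent vector, the coordinate channels are the only input that differs across positions.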

Paper Structure

This paper contains 24 sections, 22 figures, 10 tables, 1 algorithm.

Figures (22)

  • Figure 1: (left) Schematic of the Spatial Broadcast VAE. In the decoder, we broadcast (tile) a latent sample of size $k$ to the image width $w$ and height $h$, and concatenate two "coordinate" channels. This is then fed to an unstrided convolutional decoder. (right) Pseudo-code of the spatial broadcast operation, assuming access to a NumPy / TensorFlow-like API.
  • Figure 2: Comparing Deconv to Spatial Broadcast decoder in a VAE. (left) MIG results, showing a Spatial Broadcast VAE achieves higher (better) scores than a DeConv VAE. Stars are median MIG values and the seeds used for the traversals on the right. (middle) DeConv VAE reconstructions and latent space traversals. Traversals are generated around a seed point in latent space by reconstructing a sweep from -2 to +2 for each coordinate while keeping all other coordinates constant. The traversal shows an entangled representation in this model. (right) Spatial Broadcast VAE reconstructions and traversal. The traversal is well-disentangled and aligned with generative factors, as indicated by the labels on the right (which were attributed by visual inspection). While all models were trained with 10 latent coordinates, only the 8 lowest-variance ones are shown in the traversals (the remainder are non-coding coordinates).
  • Figure 3: Comparing Deconv to Spatial Broadcast decoder in a FactorVAE. (left) MIG results, showing a Spatial Broadcast FactorVAE achieves higher (better) scores than a DeConv FactorVAE. Stars are median MIG values and the seeds used for the traversals on the right. (middle) DeConv FactorVAE reconstructions and entangled latent space traversals. (right) Spatial Broadcast FactorVAE reconstructions and traversal. The traversal is well-disentangled. As in Figure 2, only the most relevant 8 of each model's 10 latent coordinates are shown in the traversals.
  • Figure 4: Rate-distortion proxy curves. We swept $\beta$ log-linearly from 0.4 to 5.4 and for each value trained 10 replicas each of Deconv $\beta$-VAE (blue) and Spatial Broadcast $\beta$-VAE (orange) on colored sprites. The dots show the mean over these replicas for each $\beta$, and the shaded region shows the hull of one standard deviation. White dots indicate $\beta=1$. (a) Reconstruction (Negative Log-Likelihood, NLL) vs KL. $\beta < 1$ yields low NLL and high KL (bottom-right of figure), whereas $\beta > 1$ yields high NLL and low KL (top-left of figure). See Alemi et al. (2017) for details. Spatial Broadcast $\beta$-VAE shows a better rate-distortion curve than Deconv $\beta$-VAE. (b) Reconstruction vs MIG metric. Values of $\beta < 1$ correspond to low NLL and low MIG scores (bottom-left of figure), and values of $\beta > 1$ correspond to high NLL and high MIG scores (towards top-right of figure). Spatial Broadcast $\beta$-VAE is better disentangled (higher MIG scores) than Deconv $\beta$-VAE.
  • Figure 5: Traversals for datasets with no positional variation. A Spatial Broadcast VAE shows good reconstructions and disentangling on the Chairs dataset (Aubry et al., 2014) and the 3D Object-in-Room dataset (Kim & Mnih, 2017). As in Figures 2 and 3, the models have 10 latent coordinates, though in these traversals the 4 non-coding ones are omitted.
  • ...and 17 more figures
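
The latent traversals described in the Figure 2 caption (sweeping one coordinate from -2 to +2 around a seed point while holding the others constant) can be sketched as follows. This is an illustrative sketch only: `latent_traversal` is a hypothetical helper, `decode` stands in for any trained decoder that maps a batch of latents to a batch of images, and the number of steps per sweep is an assumption.

```python
import numpy as np


def latent_traversal(decode, z_seed, low=-2.0, high=2.0, steps=9):
    """Return one row of decoded outputs per latent coordinate.

    decode: hypothetical callable mapping (steps, k) latents to outputs.
    z_seed: seed point in latent space, shape (k,).
    """
    z_seed = np.asarray(z_seed, dtype=float)
    values = np.linspace(low, high, steps)  # sweep values for one coordinate
    rows = []
    for i in range(z_seed.shape[0]):
        # Copy the seed `steps` times, then vary only coordinate i.
        z = np.tile(z_seed, (steps, 1))
        z[:, i] = values
        rows.append(decode(z))
    # Shape: (k, steps, ...) where ... is the decoder's output shape.
    return np.stack(rows)
```

Plotting each row of the result as a strip of images gives a grid of the kind shown in the traversal figures, with one row per latent coordinate.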