Tripod: Three Complementary Inductive Biases for Disentangled Representation Learning

Kyle Hsu, Jubayer Ibn Hamid, Kaylee Burns, Chelsea Finn, Jiajun Wu

TL;DR

This paper tackles identifiability in unsupervised disentangled representation learning by proposing Tripod, which combines three complementary inductive biases within a deterministic autoencoder: finite scalar latent quantization (FSQ), kernel-based latent multiinformation (KLM) regularization, and a normalized Hessian penalty (NHP). Each technique is adapted to overcome optimization challenges. FSQ uses a fixed codebook, which removes the losses otherwise needed to learn the quantization values. KLM uses kernel density estimation (KDE) so that a density-based multiinformation penalty can be applied to deterministic latents. NHP provides scale-invariant curvature regularization. On four image disentanglement benchmarks, Tripod achieves state-of-the-art results, and ablations show that all three legs are necessary: performance drops substantially when any leg is removed or when the techniques are combined naively. The work highlights a practical path to stronger disentanglement by re-engineering and combining existing inductive biases, at the cost of increased compute, and opens avenues for automatic tuning of the quantization and for applications to other modalities.
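The fixed-codebook idea behind FSQ can be sketched in a few lines. This is a minimal NumPy sketch of the forward pass only, assuming each latent is bounded with tanh and snapped to `levels[i]` evenly spaced values in [-1, 1]; the function name `fsq`, the tanh bounding, and the level counts are illustrative assumptions, not the paper's implementation. During training, the rounding step would be paired with a straight-through gradient estimator in an autodiff framework.

```python
import numpy as np

def fsq(z, levels):
    """Finite scalar quantization, forward pass only (illustrative sketch).

    Assumption: each latent dimension i is bounded to (-1, 1) with tanh and then
    snapped to the nearest of levels[i] evenly spaced values in [-1, 1]. The grid
    is fixed by `levels`, so there are no codebook values to learn.
    """
    z = np.asarray(z, dtype=float)
    bounded = np.tanh(z)
    out = np.empty_like(bounded)
    for i, n_levels in enumerate(levels):
        grid = np.linspace(-1.0, 1.0, n_levels)        # fixed codebook for latent i
        dist = np.abs(bounded[..., i, None] - grid)    # distance to every grid value
        out[..., i] = grid[np.argmin(dist, axis=-1)]   # keep the nearest grid value
    return out

# Two latents: the first quantized to 5 levels, the second to 3.
q = fsq(np.array([[0.0, 10.0], [-10.0, 0.3]]), levels=[5, 3])
```

Because the grid depends only on `levels`, the only learnable quantities are in the encoder and decoder, which is what lets FSQ drop the codebook-learning losses.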

Abstract

Inductive biases are crucial in disentangled representation learning for narrowing down an underspecified solution set. In this work, we consider endowing a neural network autoencoder with three select inductive biases from the literature: data compression into a grid-like latent space via quantization, collective independence amongst latents, and minimal functional influence of any latent on how other latents determine data generation. In principle, these inductive biases are deeply complementary: they most directly specify properties of the latent space, encoder, and decoder, respectively. In practice, however, naively combining existing techniques instantiating these inductive biases fails to yield significant benefits. To address this, we propose adaptations to the three techniques that simplify the learning problem, equip key regularization terms with stabilizing invariances, and quash degenerate incentives. The resulting model, Tripod, achieves state-of-the-art results on a suite of four image disentanglement benchmarks. We also verify that Tripod significantly improves upon its naive incarnation and that all three of its "legs" are necessary for best performance.

Paper Structure (23 sections, 2 theorems, 35 equations, 10 figures, 6 tables)

Key Result

Proposition 3.1

The Hessian penalty can be reduced by scaling down $\hat{g}^{[k]}$ or scaling up any $z_j$, $j \in [n_z]$, and increased by the opposite scalings. In contrast, the normalized Hessian penalty is invariant to the scaling of $\hat{g}^{[k]}$ and of $z_j$ for all $j \in [n_z]$.
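Proposition 3.1 can be checked numerically on a toy problem. The sketch below assumes a hypothetical two-latent, one-output decoder `g`, computes mixed partials exactly with central finite differences (rather than the randomized estimator used by the original Hessian penalty), and uses a *hypothetical* normalization: each mixed partial is scaled by the latent standard deviations σ_i σ_j and the result is divided by the output variance. That normalization is chosen only because it has the two invariances the proposition states; it is not claimed to be the paper's exact normalized Hessian penalty.

```python
import numpy as np

def hessian_fd(g, z, eps=1e-3):
    """Central finite-difference Hessian of a scalar function g at point z."""
    n = z.size
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            ei = np.zeros(n)
            ei[i] = eps
            ej = np.zeros(n)
            ej[j] = eps
            H[i, j] = (g(z + ei + ej) - g(z + ei - ej)
                       - g(z - ei + ej) + g(z - ei - ej)) / (4 * eps**2)
    return H

def off_diagonal_sq(H):
    """Sum of squared off-diagonal entries (the mixed partials)."""
    return np.sum(H[~np.eye(H.shape[0], dtype=bool)] ** 2)

def hessian_penalty(g, zs):
    """Plain penalty: batch mean of squared mixed partials."""
    return np.mean([off_diagonal_sq(hessian_fd(g, z)) for z in zs])

def normalized_hessian_penalty(g, zs):
    """HYPOTHETICAL normalization (illustration only): rescale H_ij by the
    latent stds sigma_i * sigma_j and divide by the variance of g's output."""
    sigma = zs.std(axis=0)
    scale = np.outer(sigma, sigma)
    out_var = np.var([g(z) for z in zs])
    return np.mean([off_diagonal_sq(hessian_fd(g, z) * scale) for z in zs]) / out_var

rng = np.random.default_rng(0)
zs = rng.normal(size=(64, 2))                  # batch of latents
g = lambda z: z[0] * z[1] + np.sin(z[0])       # toy one-output "decoder"

c, s = 10.0, 3.0
g_out = lambda z: c * g(z)                     # scale the decoder output up by c
zs_in = s * zs                                 # scale every latent up by s ...
g_in = lambda z: g(z / s)                      # ... and let the decoder undo it
```

Scaling the latents up by `s` while the decoder divides them back out leaves the reconstructions unchanged, yet it shrinks the plain penalty by a factor of `s**4`; this is the kind of degenerate incentive that a scale-invariant penalty removes.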

Figures (10)

  • Figure 1: Each of the three inductive biases for disentanglement we consider in this work specifies a different set of preferred models (circles). In principle, using them in conjunction should more precisely specify the desired solution set and better recover models akin to the true data-generating process. Our method, Tripod, makes crucial modifications to these three "legs" to realize this synergy in practice. Code is available at https://github.com/kylehkhsu/tripod.
  • Figure 2: The evolution of discrete latent space structure in autoencoders. We use finite scalar quantization (bottom right) instead of latent quantization (bottom left) so that the codebook values need not be learned.
  • Figure 3: Kernel density estimation makes it possible to regularize deterministic quantized latents away from nonzero multiinformation (left) and towards collective independence (right). The multiinformation estimate depends smoothly on the latents through distances between samples (see the joint and marginal KDE equations). The smoothing matrix $S$, set by Silverman's rule of thumb, is visualized with a level set (ellipse) of each latent sample's kernel density and incorporates each dimension's scale (the ellipse's major and minor axes). The visualized joint densities show the result of accumulating the latent samples' kernel densities at each grid point.
  • Figure 4: The Hessian penalty is supposed to specify a preference for decoders, such as the one depicted, in which change along one latent (object shape) minimally affects how another latent (horizontal end-effector position) influences data generation. We modify the Hessian penalty to quash degenerate solutions that compromise this intended outcome in autoencoders.
  • Figure 5: Qualitative study of Tripod and naive Tripod on Isaac3D. Decoded latent interventions (top): in each column, we encode an image, intervene on a single latent by varying its value via linear interpolation across that latent's range, and visualize the decoded results. Normalized mutual information heatmaps (bottom): these act as an "answer key" for the qualitative changes a column should show when the entire dataset is considered. Red latents are inactive, and their columns are removed from the latent intervention visualizations. For more qualitative results, see the qualitative results appendix.
  • ...and 5 more figures
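The kernel-based multiinformation estimate described in the Figure 3 caption can be sketched as follows. This NumPy sketch assumes Gaussian kernels with a diagonal smoothing matrix from Silverman's rule of thumb, reuses the same per-dimension bandwidths for the joint and marginal estimates (so each marginal KDE is exactly a marginal of the joint KDE), and estimates entropies by resubstitution on the batch itself. The function names are hypothetical, and this is not the paper's implementation; in training, the same computation written in an autodiff framework would pass gradients to the latents through the pairwise distances.

```python
import numpy as np

def silverman_bandwidths(z):
    """Per-dimension bandwidths (diagonal of S) from Silverman's rule of thumb."""
    n, d = z.shape
    std = np.maximum(z.std(axis=0), 1e-6)  # floor guards against constant latents
    factor = (4.0 / (d + 2)) ** (1.0 / (d + 4)) * n ** (-1.0 / (d + 4))
    return factor * std

def log_kde(queries, samples, bw):
    """Log density at `queries` of a Gaussian KDE with diagonal bandwidths `bw`."""
    diff = (queries[:, None, :] - samples[None, :, :]) / bw      # shape (q, n, d)
    log_kernel = (-0.5 * np.sum(diff**2, axis=-1)
                  - np.sum(np.log(bw)) - 0.5 * bw.size * np.log(2 * np.pi))
    m = log_kernel.max(axis=1, keepdims=True)                    # stable log-mean-exp
    return m[:, 0] + np.log(np.mean(np.exp(log_kernel - m), axis=1))

def kde_multiinformation(z):
    """Resubstitution estimate of sum_i H(z_i) - H(z) = E[log p(z) - sum_i log p(z_i)]."""
    bw = silverman_bandwidths(z)
    log_joint = log_kde(z, z, bw)
    log_marginals = sum(log_kde(z[:, [i]], z[:, [i]], bw[[i]]) for i in range(z.shape[1]))
    return np.mean(log_joint - log_marginals)

rng = np.random.default_rng(0)
z_ind = rng.integers(-1, 2, size=(500, 2)).astype(float)  # independent 3-level latents
z_dep = np.repeat(z_ind[:, :1], 2, axis=1)                 # second latent copies the first
mi_ind = kde_multiinformation(z_ind)
mi_dep = kde_multiinformation(z_dep)
```

With latents quantized to three levels, the duplicated pair gives an estimate near log 3, while the independently sampled pair gives an estimate near zero, matching the left and right panels of Figure 3 qualitatively.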

Theorems & Definitions (6)

  • Proposition 3.1
  • proof
  • Proposition 3.2
  • proof
  • proof
  • proof