
Disentanglement Learning via Topology

Nikita Balabin, Daria Voronkova, Ilya Trofimov, Evgeny Burnaev, Serguei Barannikov

TL;DR

This work introduces TopDis, a topological regularizer for disentangled representation learning that operates in an unsupervised setting and remains effective even when factors of variation are correlated. By integrating a differentiable Representation Topology Divergence (RTD) term into VAE-type losses and employing Gaussian-preserving latent shifts via group(oid) actions, TopDis enforces topological similarity across latent traversals. A gradient orthogonalization step safeguards reconstruction quality while encouraging disentanglement. Across multiple benchmarks (dSprites, 3D Shapes, 3D Faces, MPI 3D, CelebA) and several VAE variants, TopDis consistently improves MIG, FactorVAE score, SAP, and DCI disentanglement metrics while preserving reconstruction, and it can uncover disentangled directions in pretrained StyleGAN. The approach offers a flexible, topology-driven inductive bias for disentanglement with potential applicability to diverse domains beyond images.
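
The TL;DR does not spell out the gradient orthogonalization step. One plausible reading, assumed here rather than taken from the paper, is a vector projection: remove from the TopDis-loss gradient its component along the reconstruction-loss gradient, so that a small step along the remaining direction leaves the reconstruction loss unchanged to first order. The names `orthogonalize`, `g_topdis`, and `g_rec` are hypothetical, and gradients are flattened into plain Python lists for illustration.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def orthogonalize(g_topdis, g_rec, eps=1e-12):
    """Remove from g_topdis its component along g_rec (vector projection).

    Assumption: an illustrative reading of the 'gradient orthogonalization step',
    with gradients flattened into plain lists of floats.
    """
    norm_sq = dot(g_rec, g_rec)
    if norm_sq < eps:  # reconstruction gradient is ~0: nothing to protect
        return list(g_topdis)
    coef = dot(g_topdis, g_rec) / norm_sq
    return [a - coef * b for a, b in zip(g_topdis, g_rec)]


g_rec = [1.0, 0.0, 2.0]
g_topdis = [0.5, 1.0, -1.0]
g_orth = orthogonalize(g_topdis, g_rec)
# dot(g_orth, g_rec) is 0 up to rounding, so a small step along g_orth
# does not change the reconstruction loss to first order.
print(g_orth, dot(g_orth, g_rec))
```

In a training loop, the projected vector would replace the raw TopDis gradient before it is combined with the VAE-loss gradient; whether the projection is always applied or only when the two gradients conflict is a design choice this sketch does not settle.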

Abstract

We propose TopDis (Topological Disentanglement), a method for learning disentangled representations by adding a multi-scale topological loss term. Disentanglement is a crucial property of data representations, essential for the explainability and robustness of deep learning models and a step towards high-level cognition. State-of-the-art methods are based on VAEs and encourage the joint distribution of latent variables to be factorized. We take a different perspective on disentanglement by analyzing topological properties of data manifolds. In particular, we optimize the topological similarity for data manifold traversals. To the best of our knowledge, our paper is the first to propose a differentiable topological loss for disentanglement learning. Our experiments show that the proposed TopDis loss improves disentanglement scores such as MIG, FactorVAE score, SAP score, and DCI disentanglement score relative to state-of-the-art results while preserving reconstruction quality. Our method works in an unsupervised manner, permitting us to apply it to problems without labeled factors of variation. The TopDis loss works even when factors of variation are correlated. Additionally, we show how to use the proposed topological loss to find disentangled directions in a trained GAN.

Paper Structure (41 sections, 3 theorems, 16 equations, 24 figures, 14 tables, 2 algorithms)

Key Result

Proposition 4.1

a) For any fixed $\rho,\sigma$, the shift equation $F(z_{\operatorname{shifted}}) = F(z) + C$, where $F$ is the cumulative distribution function of $N(\rho,\sigma^2)$, defines a local action of the additive group $\{C\;\vert\; C\in\mathbb{R}\}$ on the real line. b) This abelian group(oid) action preserves the $N(\rho,\sigma^2)$ Gaussian distribution density. c) Conversely, if a local (group(oid)) action of this abelian …
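
Figure 4 pins the shift down concretely: $z_{\operatorname{shifted}}$ is chosen so that $F(z_{\operatorname{shifted}}) - F(z) = C$, with $F$ the Gaussian CDF. A minimal sketch of that map, assuming it is computed as $F^{-1}(F(z) + C)$, needs only the standard library's `statistics.NormalDist`; the function name `gaussian_shift` is illustrative. The `ValueError` branch reflects why the action is only local: $F(z) + C$ must stay inside $(0, 1)$.

```python
from statistics import NormalDist


def gaussian_shift(z, c, rho=0.0, sigma=1.0):
    """Return z_shifted with F(z_shifted) - F(z) = c, F being the N(rho, sigma^2) CDF.

    Assumed form F^{-1}(F(z) + c). The map is only a local action:
    it is defined when 0 < F(z) + c < 1.
    """
    dist = NormalDist(mu=rho, sigma=sigma)
    p = dist.cdf(z) + c
    if not 0.0 < p < 1.0:
        raise ValueError("F(z) + c must lie in (0, 1) for the shift to be defined")
    return dist.inv_cdf(p)


# Additive group property: two successive shifts equal one shift by the sum.
two_steps = gaussian_shift(gaussian_shift(0.3, 0.1), 0.05)
one_step = gaussian_shift(0.3, 0.15)
print(two_steps, one_step)  # equal up to floating-point rounding

# Figure 4 setting: N(0, 1) and C = 1/8; the N(0, 1) mass between z and its shift is 1/8.
z = -0.4
print(NormalDist().cdf(gaussian_shift(z, 1 / 8)) - NormalDist().cdf(z))  # ≈ 0.125
```

Composing a shift by 0.1 with a shift by 0.05 lands on the same point as a single shift by 0.15, which is the additive group property stated in part a).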

Figures (24)

  • Figure 1: The TopDis pipeline: encode a batch of data samples, apply a shift to the latent code, decode both the original and the shifted latents, and compute the TopDis loss between the two resulting point clouds. For details, see the method section.
  • Figure 2: An example of RTD calculation.
  • Figure 3: Left: rows represent point clouds (mini-batches). The 1st row is a random batch of samples; the 2nd row is obtained by shifting every sample of the 1st row equally to the right; the 3rd row has the same placement as the 2nd, but all objects are randomly transformed; the 4th row is a scaling of the samples in the 3rd row. The RTD between the 1st and 2nd point clouds is zero, as is the RTD between the 3rd and 4th rows, whereas the RTD between the 2nd and 3rd rows is large because the topological structures of these two clouds differ.
  • Figure 4: Shift of the real line preserving $N(0, 1)$, with $C=1/8$. The three orange curvilinear rectangles have the same area: $F(z_{\operatorname{shifted}})-F(z)=1/8$.
  • Figure 5: FactorVAE (left) and FactorVAE + TopDis (right) latent traversals on 3D Shapes.
  • ...and 19 more figures
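
Figures 1 and 3 describe the pipeline and the invariances RTD is meant to have. The sketch below walks through the four pipeline steps in plain Python under loud assumptions: `encode` and `decode` are hypothetical stand-ins for the VAE encoder and decoder, one latent coordinate of every code is shifted by the same C in the CDF space of an assumed N(0, 1) prior, and `proxy_divergence` is a deliberately crude substitute for RTD (mean absolute gap between max-normalized pairwise distance matrices), not RTD itself. Like RTD in Figure 3, the proxy is zero for a translated or uniformly rescaled copy of a cloud, but it computes no topological barcodes.

```python
import math
from statistics import NormalDist

PRIOR = NormalDist()  # assumed N(0, 1) latent prior


def pairwise_distances(points):
    return [[math.dist(p, q) for q in points] for p in points]


def max_normalized(dmat):
    # Dividing by the largest distance makes the matrix invariant to uniform rescaling.
    scale = max(max(row) for row in dmat) or 1.0
    return [[d / scale for d in row] for row in dmat]


def proxy_divergence(cloud_a, cloud_b):
    """Crude stand-in for RTD (NOT RTD): mean absolute gap between the
    max-normalized distance matrices of two clouds with a 1-to-1 correspondence."""
    da = max_normalized(pairwise_distances(cloud_a))
    db = max_normalized(pairwise_distances(cloud_b))
    n = len(da)
    return sum(abs(da[i][j] - db[i][j]) for i in range(n) for j in range(n)) / (n * n)


def shift_latents(latents, index, c):
    """Shift coordinate `index` of every code by the same c in the CDF space of PRIOR."""
    shifted = []
    for z in latents:
        p = PRIOR.cdf(z[index]) + c
        if not 0.0 < p < 1.0:
            raise ValueError("shift undefined for this latent value")
        z_new = list(z)
        z_new[index] = PRIOR.inv_cdf(p)
        shifted.append(z_new)
    return shifted


def topdis_proxy_loss(batch, encode, decode, index, c):
    latents = [encode(x) for x in batch]          # 1. encode the batch
    shifted = shift_latents(latents, index, c)    # 2. shift one latent coordinate
    recon = [decode(z) for z in latents]          # 3a. decode the original codes
    recon_shifted = [decode(z) for z in shifted]  # 3b. decode the shifted codes
    return proxy_divergence(recon, recon_shifted)  # 4. compare the two point clouds


# Figure 3 invariances, checked on the proxy: translation and uniform scaling give ~0.
cloud = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0)]
print(proxy_divergence(cloud, [(x + 5.0, y - 1.0) for x, y in cloud]))
print(proxy_divergence(cloud, [(3.0 * x, 3.0 * y) for x, y in cloud]))
```

Swapping `proxy_divergence` for a real differentiable RTD implementation and the identity maps for a trained encoder and decoder is what would turn this outline into something resembling the TopDis loss.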

Theorems & Definitions (4)

  • Proposition 4.1
  • Proposition 4.2
  • Remark 2.1
  • Proposition 21.1