An Information Criterion for Controlled Disentanglement of Multimodal Data
Chenyu Wang, Sharut Gupta, Xinyi Zhang, Sana Tonekaboni, Stefanie Jegelka, Tommi Jaakkola, Caroline Uhler
TL;DR
Multimodal data often entangles shared and modality-specific information, which can make the Minimum Necessary Information (MNI) point unattainable in real settings. DisentangledSSL introduces a two-step, self-supervised optimization that learns a shared latent $Z_c$ and modality-specific latents $Z_s^1$, $Z_s^2$. The procedure is guided by an information-theoretic objective and the information bottleneck (IB) curve, and it covers both the case where MNI is attainable and the case where it is not. The authors prove optimality guarantees for the learned shared representations and show that the modality-specific representations achieve coverage and disentanglement in both regimes, using a tractable training objective that combines InfoNCE with mutual information (MI) bounds. Empirically, DisentangledSSL improves downstream performance on vision-language prediction and molecule-phenotype retrieval across synthetic and real-world multimodal datasets, outperforming several baselines and demonstrating robust disentanglement.
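To make the InfoNCE term mentioned above concrete, the following is a minimal sketch of a generic one-directional InfoNCE loss over paired embeddings from two modalities. It is an illustration under stated assumptions, not the authors' training objective: the function name `info_nce`, the cosine-similarity critic, and the `temperature` default are all hypothetical choices, and the additional MI bounds used by DisentangledSSL are not shown. The standard relation $I(Z^1; Z^2) \geq \log N - \mathcal{L}_{\text{InfoNCE}}$ is what makes minimizing this loss a way of maximizing a lower bound on shared information.

```python
import math


def info_nce(z1, z2, temperature=0.1):
    """Generic one-directional InfoNCE loss (illustrative sketch only).

    z1, z2: lists of N embedding vectors (lists of floats); z1[i] and z2[i]
    are assumed to be a positive pair, and all z2[j] with j != i act as
    negatives. Vectors are assumed to be nonzero.
    """
    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    a = [normalize(v) for v in z1]
    b = [normalize(v) for v in z2]
    n = len(a)

    total = 0.0
    for i in range(n):
        # Cosine similarity of anchor i against every candidate, scaled by temperature.
        logits = [sum(p * q for p, q in zip(a[i], b[j])) / temperature
                  for j in range(n)]
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        # Negative log-probability of picking the true partner.
        total += log_denom - logits[i]
    return total / n


# Two perfectly aligned pairs versus the same embeddings with partners swapped.
aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
mi_lower_bound = math.log(2) - aligned  # log N - loss, with N = 2
```

In this toy case the aligned pairs give a loss close to zero, so the bound approaches $\log 2$, while swapping the partners drives the loss up sharply. A practical implementation would typically use a batched tensor library and a symmetric loss over both directions.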
Abstract
Multimodal representation learning seeks to relate and decompose information inherent in multiple modalities. By disentangling modality-specific information from information that is shared across modalities, we can improve interpretability and robustness and enable downstream tasks such as the generation of counterfactual outcomes. Separating the two types of information is challenging since they are often deeply entangled in many real-world applications. We propose Disentangled Self-Supervised Learning (DisentangledSSL), a novel self-supervised approach for learning disentangled representations. We present a comprehensive analysis of the optimality of each disentangled representation, particularly focusing on the scenario not covered in prior work where the so-called Minimum Necessary Information (MNI) point is not attainable. We demonstrate that DisentangledSSL successfully learns shared and modality-specific features on multiple synthetic and real-world datasets and consistently outperforms baselines on various downstream tasks, including prediction tasks for vision-language data, as well as molecule-phenotype retrieval tasks for biological data. The code is available at https://github.com/uhlerlab/DisentangledSSL.
