Supervised Contrastive Block Disentanglement

Taro Makino, Ji Won Park, Natasa Tagasovska, Takamasa Kudo, Paula Coelho, Jan-Christian Huetter, Heming Yao, Burkhard Hoeckendorf, Ana Carolina Leote, Stephen Ra, David Richmond, Kyunghyun Cho, Aviv Regev, Romain Lopez

TL;DR

SCBD introduces a two-embedding framework that learns ${\mathbf{z}}_c$ to capture the phenomenon of interest while remaining invariant to environment ${e}$, and ${\mathbf{z}}_s$ to model spurious environment-related variation. Built on Supervised Contrastive Learning, SCBD combines two supervised contrastive terms with a novel invariance loss scaled by ${\alpha}$ and an optional reconstruction loss, avoiding adversarial training. Empirical results show strong out-of-distribution generalization on CMNIST and Camelyon17-WILDS and effective batch correction on Optical Pooled Screen data, with higher ${\alpha}$ yielding better invariance at the cost of in-distribution performance. The work offers a practical, hyperparameter-tunable approach to block disentanglement that outperforms variational baselines like iVAE and supports downstream tasks requiring robust, environment-agnostic representations.
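
Read literally, the TL;DR describes a training objective with roughly the shape below. This is a schematic under stated assumptions, not the paper's equation. Assigning $y$ to ${\mathbf{z}}_c$ and $e$ to ${\mathbf{z}}_s$ in the two supervised contrastive terms follows the abstract's description of what each embedding should correlate with. $\mathcal{L}_{\text{inv}}$ is left abstract because its form is not given in this summary, and $\lambda$ is a hypothetical weight for the optional reconstruction term.

$$
\mathcal{L}_{\text{SCBD}}
  = \mathcal{L}_{\text{SupCon}}({\mathbf{z}}_c;\, y)
  + \mathcal{L}_{\text{SupCon}}({\mathbf{z}}_s;\, e)
  + \alpha\, \mathcal{L}_{\text{inv}}({\mathbf{z}}_c;\, e)
  + \lambda\, \mathcal{L}_{\text{rec}}
$$

Setting $\lambda = 0$ drops the optional reconstruction term. Raising $\alpha$ weights the invariance term more heavily, which the summary reports improves out-of-distribution performance at the cost of in-distribution performance.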

Abstract

Real-world datasets often combine data collected under different experimental conditions. This yields larger datasets, but also introduces spurious correlations that make it difficult to model the phenomena of interest. We address this by learning two embeddings to independently represent the phenomena of interest and the spurious correlations. The embedding representing the phenomena of interest is correlated with the target variable $y$, and is invariant to the environment variable $e$. In contrast, the embedding representing the spurious correlations is correlated with $e$. The invariance to $e$ is difficult to achieve on real-world datasets. Our primary contribution is an algorithm called Supervised Contrastive Block Disentanglement (SCBD) that effectively enforces this invariance. It is based purely on Supervised Contrastive Learning, and applies to real-world data better than existing approaches. We empirically validate SCBD on two challenging problems. The first problem is domain generalization, where we achieve strong performance on a synthetic dataset, as well as on Camelyon17-WILDS. We introduce a single hyperparameter $\alpha$ to control the degree of invariance to $e$. When we increase $\alpha$ to strengthen the degree of invariance, out-of-distribution performance improves at the expense of in-distribution performance. The second problem is batch correction, in which we apply SCBD to preserve biological signal and remove inter-well batch effects when modeling single-cell perturbations from 26 million Optical Pooled Screening images.
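
Since SCBD is described as built purely on Supervised Contrastive Learning, a minimal NumPy sketch of the standard supervised contrastive loss (the $\mathcal{L}_{\text{out}}$ form of Khosla et al., 2020) may help make that building block concrete. The function name `supcon_loss`, the temperature default, and skipping anchors that have no positives are assumptions for illustration, not details taken from the SCBD paper; the invariance loss is not shown.

```python
import numpy as np


def supcon_loss(z: np.ndarray, labels: np.ndarray, temperature: float = 0.1) -> float:
    """Supervised contrastive loss, L_out variant (hypothetical helper, not SCBD's code).

    z: (n, d) embeddings with n >= 2; labels: (n,) integer labels.
    Samples sharing a label are positives; every other sample is in the denominator.
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # project onto the unit sphere
    logits = z @ z.T / temperature                     # (n, n) scaled cosine similarities
    n = z.shape[0]
    self_mask = np.eye(n, dtype=bool)
    logits = np.where(self_mask, -np.inf, logits)      # an anchor is never its own contrast
    # Numerically stable log-softmax over all non-anchor samples, row by row.
    row_max = np.max(logits, axis=1, keepdims=True)
    log_denom = row_max + np.log(np.sum(np.exp(logits - row_max), axis=1, keepdims=True))
    log_prob = logits - log_denom
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    n_pos = pos_mask.sum(axis=1)
    valid = n_pos > 0  # assumption: anchors without any positive are skipped
    # Average log-probability assigned to the positives of each valid anchor.
    mean_log_prob_pos = np.where(pos_mask, log_prob, 0.0).sum(axis=1)[valid] / n_pos[valid]
    return float(-mean_log_prob_pos.mean())
```

If ${\mathbf{z}}_c$ is paired with $y$ and ${\mathbf{z}}_s$ with $e$ (an assumption based on the abstract), the two contrastive terms of an SCBD-style objective would be `supcon_loss(z_c, y)` and `supcon_loss(z_s, e)`. An actual training loop would use an autodiff framework rather than NumPy so that gradients reach the encoders.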

Paper Structure

This paper contains 34 sections, 13 equations, 24 figures, 7 tables.

Figures (24)

  • Figure 1: Spurious correlations emerge when collecting medical images from different hospitals, or conducting single-cell perturbation screens across multiple wells. ${\mathbf{z}}_s$ models these spurious correlations, while ${\mathbf{z}}_c$ models the environment-invariant correlations.
  • Figure 2: Colored MNIST. (a) There is an environment-dependent correlation between color and digit on the training set, which does not persist on the test set where all digits are white. (b) We can generate images counterfactually using SCBD. When we swap ${\mathbf{z}}_c$ across examples, it changes the digit without affecting the color. In contrast, when we swap ${\mathbf{z}}_s$ across examples, it changes the color without affecting the digit. By composing digit and color independently, we generate images outside of the support of the training distribution, such as a light red one (bottom middle) and a bright green five (bottom right).
  • Figure 3: Increasing $\alpha$ strengthens the degree to which ${\mathbf{z}}_c$ is invariant to $e$, and monotonically improves test accuracy at the expense of validation accuracy.
  • Figure 4: Comparison of SCBD to CellProfiler and VAE-based baselines on real-world batch correction. Left: Performance on predicting protein complex membership (biological content). Higher is better. Right: Performance on predicting the well label $e$. Lower is better. SCBD is unambiguously better than all baselines: it preserves more biological signal while being less sensitive to inter-well batch effects.
  • Figure 5: In- and out-of-distribution performance are negatively correlated on CMNIST, which satisfies the assumptions made by SCBD.
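
Figure 4 (right) scores a representation by how well the well label $e$ can be predicted from it, with lower being better. A simple way to run this kind of check is a held-out probe. The sketch below uses a nearest-centroid classifier in NumPy as a stand-in, since the probe used in the paper is not described in this summary; the function name `env_probe_accuracy` and the 50/50 split are hypothetical.

```python
import numpy as np


def env_probe_accuracy(z: np.ndarray, e: np.ndarray, train_frac: float = 0.5, seed: int = 0) -> float:
    """Held-out accuracy of a nearest-centroid probe predicting e from z (hypothetical helper).

    z: (n, d) embeddings, e.g. z_c; e: (n,) environment labels.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(z))            # random train/test split
    n_train = int(train_frac * len(z))
    tr, te = idx[:n_train], idx[n_train:]
    envs = np.unique(e[tr])
    # One centroid per environment seen in the training split: shape (K, d).
    centroids = np.stack([z[tr][e[tr] == k].mean(axis=0) for k in envs])
    # Squared Euclidean distance from every test point to every centroid: shape (n_test, K).
    d2 = ((z[te][:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    pred = envs[np.argmin(d2, axis=1)]
    return float((pred == e[te]).mean())
```

Accuracy close to chance ($1/K$ for $K$ balanced environments) suggests that ${\mathbf{z}}_c$ carries little environment information this probe can exploit. A stronger probe, such as logistic regression or a small MLP, can detect leakage that nearest-centroid misses, so a low score here is only weak evidence of invariance.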