Sparsity regularization via tree-structured environments for disentangled representations
Elliot Layne, Jason Hartford, Sébastien Lachapelle, Mathieu Blanchette, Dhanya Sridhar
TL;DR
The paper tackles learning disentangled latent variables from measurements collected across related environments arranged as a tree, a setting common in biology. It introduces Tree-Based Regularization (TBR), which jointly learns an encoder, a root parameter $\hat{w}_0$, and edge mutations $\hat{\Delta}$ by minimizing prediction error across the leaf environments while enforcing sparsity on $\hat{\Delta}$, reflecting the assumption that $P_e(Y|\mathbf{Z})$ changes sparsely between related environments. The authors prove that, under the proposed sparsity regime, the latent variables are identifiable up to permutation and scaling. The argument has two parts: it first shows that a nonlinear encoder recovers the latents up to a linear transformation (linear identifiability), and then uses $L_0$ sparsity of the edge mutations to rule out entangled linear solutions. Empirically, TBR outperforms a baseline at disentangling the latents, as measured by the mean correlation coefficient (MCC), maintains lower predictive error across varying levels of sparsity, and generalizes better to unseen environments, including in single-cell RNA-seq settings. These findings highlight the potential of exploiting hierarchical, sparse environment changes for causal discovery and robust transfer in complex biological data, and the authors outline directions for relaxing the linearity assumption and extending the theory.
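The TBR objective described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes a linear predictor $y \approx w_e^\top \hat{\mathbf{z}}$ at each leaf, squared-error loss averaged over all leaf samples, and an $L_1$ penalty on the edge mutations as a convex stand-in for the $L_0$ sparsity used in the theory. The weights at a node are the root parameter $\hat{w}_0$ plus the mutations on every edge along the path from the root. The names `node_weights` and `tbr_objective`, the `parent`/`delta` dictionaries, and the identity encoder in the example are hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]


def node_weights(w0: Sequence[float], parent: Dict[str, str],
                 delta: Dict[str, Sequence[float]], node: str) -> Vector:
    """Hypothetical helper: root weights plus every edge mutation on the path
    from the root to `node`. delta[v] is the mutation on edge parent[v] -> v."""
    w = list(w0)
    while node in parent:  # the root has no parent entry, so the walk stops there
        w = [wi + di for wi, di in zip(w, delta[node])]
        node = parent[node]
    return w


def tbr_objective(encoder: Callable[[Sequence[float]], Vector],
                  w0: Sequence[float],
                  parent: Dict[str, str],
                  delta: Dict[str, Sequence[float]],
                  leaf_data: Dict[str, List[Tuple[Sequence[float], float]]],
                  lam: float) -> float:
    """Hypothetical TBR-style loss: mean squared prediction error over all leaf
    samples plus lam * L1 norm of all edge mutations (L1 stands in for L0)."""
    sq_err, n = 0.0, 0
    for leaf, samples in leaf_data.items():
        w = node_weights(w0, parent, delta, leaf)
        for x, y in samples:
            z = encoder(x)  # learned latents for this observation
            pred = sum(wi * zi for wi, zi in zip(w, z))
            sq_err += (y - pred) ** 2
            n += 1
    sparsity = sum(abs(d) for vec in delta.values() for d in vec)
    return sq_err / n + lam * sparsity


def identity(x: Sequence[float]) -> Vector:
    """Placeholder encoder; in TBR this would be a learned (nonlinear) map."""
    return list(x)


# Usage: a root with two leaf environments, A and B; only A mutates (on latent 2).
parent = {"A": "root", "B": "root"}
delta = {"A": [0.0, 1.0], "B": [0.0, 0.0]}
w0 = [1.0, 0.0]
leaf_data = {"A": [([1.0, 2.0], 3.0)], "B": [([2.0, 5.0], 2.0)]}
print(node_weights(w0, parent, delta, "A"))  # prints [1.0, 1.0]
print(tbr_objective(identity, w0, parent, delta, leaf_data, lam=0.1))  # prints 0.1
```

Because a mutation on an internal edge is counted once in the penalty but applies to every leaf below it, this parameterization favors explaining a change shared by a whole subtree once, near the root of that subtree, rather than separately at each leaf.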
Abstract
Many causal systems, such as biological processes in cells, can only be observed indirectly via measurements such as gene expression. Causal representation learning -- the task of correctly mapping low-level observations to latent causal variables -- could advance scientific understanding by enabling inference of latent variables such as pathway activation. In this paper, we develop methods for inferring latent variables from multiple related datasets (environments) and tasks. As a running example, we consider the task of predicting a phenotype from gene expression, where data are often collected from multiple cell types or organisms that are related in known ways. The key insight is that the mapping from the latent variables inferred from gene expression to the phenotype of interest changes sparsely across closely related environments. To model these sparse changes, we introduce Tree-Based Regularization (TBR), an objective that minimizes prediction error while regularizing closely related environments to learn similar predictors. We prove that, under assumptions about the degree of sparsity of these changes, TBR identifies the true latent variables up to permutation and scaling. We evaluate the theory empirically on both simulations and ground-truth gene expression data, and find that TBR recovers the latent causal variables better than related methods across these settings, even when some assumptions of the theory are violated.
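Recovery of the latent variables is reported in the TL;DR via the mean correlation coefficient (MCC). A common definition, assumed here because the exact variant the authors use is not stated in this section, computes the absolute Pearson correlation between every true and every learned latent, matches them one-to-one so that the average matched correlation is maximal, and reports that average. Because absolute correlation ignores rescaling and sign, and the matching ignores ordering, a representation equal to the truth up to permutation and scaling scores 1. The sketch below brute-forces the matching over all permutations, which is only practical for a handful of latents; the function names `pearson` and `mcc` and the toy data are hypothetical.

```python
import itertools
import math
from typing import List, Sequence


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation of two equal-length, non-constant sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb)


def mcc(z_true: List[Sequence[float]], z_hat: List[Sequence[float]]) -> float:
    """Mean absolute correlation under the best one-to-one matching of latents.
    Rows are samples, columns are latent dimensions (same count in both)."""
    d = len(z_true[0])
    true_cols = [[row[i] for row in z_true] for i in range(d)]
    hat_cols = [[row[j] for row in z_hat] for j in range(d)]
    corr = [[abs(pearson(t, h)) for h in hat_cols] for t in true_cols]
    # Exhaustive search over matchings: O(d!), fine for small d only.
    best = max(sum(corr[i][p[i]] for i in range(d))
               for p in itertools.permutations(range(d)))
    return best / d


z_true = [[1.0, 0.0], [2.0, 1.0], [3.0, 5.0], [4.0, 2.0]]
# Permuted and rescaled copy of the true latents: MCC is 1 (up to rounding).
z_perm = [[10.0 * b, -0.5 * a] for a, b in z_true]
# Entangled mixture of both latents: MCC drops (about 0.82 for this data).
z_mixed = [[a + b, a - b] for a, b in z_true]
print(round(mcc(z_true, z_perm), 6), round(mcc(z_true, z_mixed), 2))  # prints 1.0 0.82
```

For more than a few latents, an assignment solver such as SciPy's `scipy.optimize.linear_sum_assignment` with `maximize=True`, applied to the correlation matrix, replaces the factorial search.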
