Causal Structure and Representation Learning with Biomedical Applications
Caroline Uhler, Jiaqi Zhang
TL;DR
The paper tackles how to fuse representation learning with causal inference in biomedical settings, emphasizing multi-modal observational and perturbational data. It surveys causal discovery algorithms (PC, GES, GSP) and their use with interventional data, highlighting identifiability limits under faithfulness and finite samples. It then develops causal representation learning (CRL) frameworks for single-modality, interventional, and multi-modal data, providing identifiability results and practical algorithms (e.g., leaf detection via Jacobians and constrained optimization) to recover latent causal variables and their relations, up to equivalence. The work applies these ideas to gene regulatory networks and Perturb-seq data, and advocates for causal experimental design to efficiently select informative perturbations and modalities, with broad implications for accelerating biomedical discovery.
Abstract
Massive data collection holds the promise of a better understanding of complex phenomena and, ultimately, better decisions. Representation learning has become a key driver of deep learning applications, as it allows learning latent spaces that capture important properties of the data without requiring any supervised annotations. Although representation learning has been hugely successful in predictive tasks, it can fail miserably in causal tasks including predicting the effect of a perturbation/intervention. This calls for a marriage between representation learning and causal inference. An exciting opportunity in this regard stems from the growing availability of multi-modal data (observational and perturbational, imaging-based and sequencing-based, at the single-cell level, tissue-level, and organism-level). We outline a statistical and computational framework for causal structure and representation learning motivated by fundamental biomedical questions: how to effectively use observational and perturbational data to perform causal discovery on observed causal variables; how to use multi-modal views of the system to learn causal variables; and how to design optimal perturbations.
