Predicting Cellular Responses with Variational Causal Inference and Refined Relational Information

Yulun Wu, Robert A. Barton, Zichen Wang, Vassilis N. Ioannidis, Carlo De Donno, Layne C. Price, Luis F. Voloch, George Karypis

TL;DR

The paper tackles predicting cellular responses to perturbations, a key challenge for drug discovery and personalized therapeutics. It introduces GraphVCI, a graph-structured variational causal model that embeds gene regulatory networks (GRNs) into the latent space to model counterfactual gene expressions, augmented by an adjacency-refinement step and a robust estimator for marginal perturbation effects. Core contributions include a variational objective for counterfactuals, a relational-information mechanism with graph attention, and a scalable GRN refinement procedure that improves edge relevance and model performance, plus an asymptotically efficient estimator for population-level effects. On three benchmark datasets, GraphVCI achieves state-of-the-art out-of-distribution predictions, with ablations confirming the value of refined relational information and biological alignment of learned edges, enabling more accurate and interpretable single-cell perturbation predictions.

Abstract

Predicting the responses of a cell under perturbations may bring important benefits to drug discovery and personalized therapeutics. In this work, we propose a novel graph variational Bayesian causal inference framework to predict a cell's gene expressions under counterfactual perturbations (perturbations that this cell did not factually receive), leveraging biological knowledge in the form of gene regulatory networks (GRNs) to aid individualized cellular response predictions. Aiming at a data-adaptive GRN, we also develop an adjacency matrix updating technique for graph convolutional networks and use it to refine GRNs during pre-training, which yields more insight into gene relations and enhances model performance. Additionally, we propose a robust estimator within our framework for the asymptotically efficient estimation of the marginal perturbation effect, which previous works have not yet carried out. With extensive experiments, we demonstrate the advantage of our approach over state-of-the-art deep learning models for individual response prediction.

Paper Structure

This paper contains 22 sections, 2 theorems, 15 equations, 6 figures and 2 tables.

Key Result

Theorem 1

Suppose that $\mathpzc{W}=(\mathpzc{G}, X, Z, T, T', Y, Y')$ follows the causal structure defined by the Bayesian network in Figure 1. Then $J(\mathpzc{D})$ has a variational lower bound, expressed in terms of $D [ p \parallel q ] = \log p - \log q$.
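The pointwise term $D[p \parallel q] = \log p - \log q$ has a useful property: averaged over $p$, it becomes the Kullback-Leibler divergence. This is a standard identity, not part of the theorem statement, and it is consistent with the KL term between the conditional distributions of $Z$ and $Z'$ in the Figure 2 objective:

```latex
\mathbb{E}_{Z \sim p}\big[\log p(Z) - \log q(Z)\big]
  = \int p(Z)\,\log\frac{p(Z)}{q(Z)}\,\mathrm{d}Z
  = \mathrm{KL}\big(p \,\|\, q\big) \;\ge\; 0 ,
```

Equality holds if and only if $p = q$ almost everywhere (Gibbs' inequality).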

Figures (6)

  • Figure 1: The causal relation diagram. Each individual has a feature state $Z$ following a conditional distribution $p(Z | \mathpzc{G}, X)$. Treatment $T$ (or counterfactual treatment $T'$), together with $Z$, determines outcome $Y$ (or counterfactual outcome $Y'$). In the causal diagram, white nodes are observed and dark grey nodes are unobserved; dashed relations are optional (case-dependent). In the context of this paper, graph $\mathpzc{G}$ is a deterministic variable that is invariant across all individuals.
  • Figure 2: Model workflow (variational causal perspective). In a forward pass, the graphVCI encoder takes graph $\mathpzc{G}$ (e.g., a gene relation graph), outcome $Y$ (e.g., gene expressions), covariates $X$ (e.g., cell types, donors) and treatment $T$ (e.g., a drug perturbation) as inputs and generates latent $Z$. The pairs $(Z, T)$ and $(Z, T')$, where $T'$ is a randomly sampled counterfactual treatment, are passed separately into the graphVCI decoder to obtain a reconstruction of $Y$ and a construction of the counterfactual outcome $Y'$. $Y'$ is then passed back into the encoder along with $\mathpzc{G}$, $X$ and $T'$ to obtain the counterfactual latent $Z'$. The objective consists of the reconstruction loss of $Y$, the distribution loss of $Y'$ and the KL-divergence between the conditional distributions of $Z$ and $Z'$.
  • Figure 3: Model architecture (graph attentional perspective). Structure of the graphVCI encoder and decoder, as defined by the paper's four encoder equations and two decoder equations. For single-cell perturbation datasets, the graph inputs are fixed across samples, so graph attention essentially reduces to weighted graph convolution.
  • Figure 4: An example of an updated gene regulatory network $U$ (where $U_{i,j} = | \tilde{W}_{i,j} - \alpha |$ with $\alpha=0.2$), obtained by refining the original ATAC-seq-based network (Kamimoto et al., 2020) using the Schmidt et al. (2022) dataset. Source nodes are shown as rows and targets as columns for key immune-related genes. The learned edge weights in (b) recapitulate known biology, such as STAT1 regulating IRF1. Many of the edges are already present in the original ATAC-seq-based network in (a), but some novel edges appear in (b), such as IFNg regulating MYC (Ramana et al., 2000).
  • Figure 5: Model predictions versus true distributions for gene overexpression in CRISPRa experiments (Schmidt et al., 2022). For two perturbations in CD8$^{+}$ T cells, (a) MAP4K1 overexpression and (b) GATA3 overexpression, we plot the distribution of gene expressions for unperturbed cells ("Control"), the model's prediction of perturbed gene expressions using unperturbed cells as factual inputs ("Pred"), and the true gene expressions for perturbed cells ("True"). The predicted distributional shift relative to control often matches the direction of the true shift.
  • ...and 1 more figures
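The data flow in the Figure 2 workflow can be sketched as follows. This is a minimal, hypothetical NumPy illustration, not the authors' implementation, and it rests on several assumptions:

- Random linear maps (`W_ENC`, `W_DEC`) stand in for the graph-attention encoder and decoder.
- The graph $\mathpzc{G}$ and covariates $X$ are dropped.
- $T'$ is drawn uniformly from treatments other than $T$.
- The KL direction is taken as $\mathrm{KL}(q(Z \mid Y, T) \,\|\, q(Z' \mid Y', T'))$.
- The distribution loss of $Y'$ is omitted, since it needs a separate model of the counterfactual outcome distribution.

All function and variable names are illustrative.

```python
import numpy as np

# Hypothetical toy sizes: genes in Y, latent dimensions of Z, number of treatments.
N_GENES, D_LATENT, N_TREAT = 5, 3, 4
rng = np.random.default_rng(0)

# Hypothetical linear weights standing in for the graphVCI encoder/decoder.
# The graph G and covariates X are omitted in this sketch.
W_ENC = 0.1 * rng.normal(size=(N_GENES + N_TREAT, 2 * D_LATENT))
W_DEC = 0.1 * rng.normal(size=(D_LATENT + N_TREAT, N_GENES))


def one_hot(t: int) -> np.ndarray:
    v = np.zeros(N_TREAT)
    v[t] = 1.0
    return v


def encode(y: np.ndarray, t: int) -> tuple[np.ndarray, np.ndarray]:
    """Mean and log-variance of a diagonal Gaussian q(Z | Y, T)."""
    h = np.concatenate([y, one_hot(t)]) @ W_ENC
    return h[:D_LATENT], h[D_LATENT:]


def decode(z: np.ndarray, t: int) -> np.ndarray:
    """Deterministic decoder output for the outcome given (Z, T)."""
    return np.concatenate([z, one_hot(t)]) @ W_DEC


def kl_diag_gauss(mu1, logv1, mu2, logv2) -> float:
    """KL(N(mu1, exp(logv1)) || N(mu2, exp(logv2))) for diagonal Gaussians."""
    return float(0.5 * np.sum(
        logv2 - logv1 + (np.exp(logv1) + (mu1 - mu2) ** 2) / np.exp(logv2) - 1.0
    ))


def forward(y: np.ndarray, t: int) -> tuple[float, float]:
    """One forward pass; returns (reconstruction loss of Y, KL between Z and Z')."""
    mu, logv = encode(y, t)
    z = mu + np.exp(0.5 * logv) * rng.normal(size=D_LATENT)  # reparameterized sample of Z
    y_hat = decode(z, t)                                      # reconstruction of Y
    t_cf = int(rng.choice([k for k in range(N_TREAT) if k != t]))  # assumed: T' != T
    y_cf = decode(z, t_cf)                                    # counterfactual outcome Y'
    mu_cf, logv_cf = encode(y_cf, t_cf)                       # counterfactual latent Z'
    recon = float(np.mean((y - y_hat) ** 2))
    kl = kl_diag_gauss(mu, logv, mu_cf, logv_cf)
    return recon, kl


recon, kl = forward(rng.normal(size=N_GENES), t=0)
print(f"reconstruction={recon:.4f}  KL(Z || Z')={kl:.4f}")
```

In a trained model these terms, plus the omitted distribution loss of $Y'$, would be weighted and minimized by gradient descent. NumPy is used here only to make the order of encoder and decoder calls concrete.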

Theorems & Definitions (4)

  • Theorem 1
  • proof
  • Theorem 2
  • proof