Representation Learning for Distributional Perturbation Extrapolation

Julius von Kügelgen, Jakob Ketterer, Xinwei Shen, Nicolai Meinshausen, Jonas Peters

TL;DR

The paper tackles distributional perturbation extrapolation: predicting the full distribution of omics-like observations under unseen perturbations. It introduces a latent-variable model in which perturbations act as additive mean shifts, ${\bm{Z}}^{\mathrm{pert}}={\bm{Z}}^{\mathrm{base}}+{\bm{W}}{\bm{l}}$, followed by a nonlinear generator ${\bm{f}}({\bm{Z}}^{\mathrm{pert}}, {\bm{\varepsilon}})$. Given sufficiently diverse training perturbations, the model is proven to be identifiable up to an affine transformation, which enables extrapolation to perturbations in the span of the training labels ${\bm{l}}_e$. The Perturbation Distribution Autoencoder (PDAE) is proposed to estimate the identifiable components by matching distributional predictions across domains using the energy score; it consists of an encoder, a perturbation module, and a stochastic decoder. Empirical results on synthetic data show that PDAE achieves superior distributional fidelity and mean prediction in the in-distribution regime and provides the best-performing extrapolation among methods not explicitly designed for distributional reconstruction. The authors acknowledge limitations for out-of-distribution generalization and decoder extrapolation. The work advances the theory of identifiable representation learning for extrapolation and links to causal modelling by framing perturbations as latent shifts, with practical relevance for predicting the effects of unseen perturbation combinations in biology and related domains.

Abstract

We consider the problem of modelling the effects of unseen perturbations such as gene knockdowns or drug combinations on low-level measurements such as RNA sequencing data. Specifically, given data collected under some perturbations, we aim to predict the distribution of measurements for new perturbations. To address this challenging extrapolation task, we posit that perturbations act additively in a suitable, unknown embedding space. More precisely, we formulate the generative process underlying the observed data as a latent variable model, in which perturbations amount to mean shifts in latent space and can be combined additively. Unlike previous work, we prove that, given sufficiently diverse training perturbations, the representation and perturbation effects are identifiable up to affine transformation, and use this to characterize the class of unseen perturbations for which we obtain extrapolation guarantees. To estimate the model from data, we propose a new method, the perturbation distribution autoencoder (PDAE), which is trained by maximising the distributional similarity between true and predicted perturbation distributions. The trained model can then be used to predict previously unseen perturbation distributions. Empirical evidence suggests that PDAE compares favourably to existing methods and baselines at predicting the effects of unseen perturbations.
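The generative process described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' code: the standard Gaussian basal state, the function names `simulate_environment` and `toy_decoder`, the specific nonlinearity, and the choice $d_X = d_Z$ are all hypothetical and chosen only to make the additive latent shift concrete.

```python
import numpy as np

def simulate_environment(label, W, decoder, n_samples, noise_scale=0.0, rng=None):
    """Draw observations for one perturbation label under the assumed model.

    label: (K,) perturbation label l
    W:     (d_Z, K) perturbation matrix
    """
    rng = np.random.default_rng() if rng is None else rng
    d_z = W.shape[0]
    # Latent basal state; a standard Gaussian is an assumption of this sketch.
    z_base = rng.standard_normal((n_samples, d_z))
    # Perturbations act as a mean shift in latent space: z_pert = z_base + W l.
    z_pert = z_base + W @ label
    # Decoder noise; noise_scale=0 gives a decoder that ignores its second argument.
    eps = noise_scale * rng.standard_normal((n_samples, d_z))
    return decoder(z_pert, eps)

def toy_decoder(z, eps):
    """Hypothetical nonlinear mixing with d_X = d_Z; not the paper's decoder."""
    return np.tanh(z) + 0.1 * z**3 + eps

# Usage: K = 3 elementary perturbations, d_Z = 2 latent dimensions.
rng = np.random.default_rng(0)
W = np.array([[1.0, 0.0, 2.0],
              [0.0, 1.0, -1.0]])
x_single = simulate_environment(np.array([1.0, 0.0, 0.0]), W, toy_decoder, 1000, rng=rng)
x_combo = simulate_environment(np.array([1.0, 1.0, 0.0]), W, toy_decoder, 1000, rng=rng)
```

Because the shift is linear in the label, combining two perturbations adds their latent shifts, $W(l_1 + l_2) = W l_1 + W l_2$, even though the resulting change in observation space is nonlinear.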

Paper Structure

This paper contains 68 sections, 7 theorems, 79 equations, 5 figures, 2 tables.

Key Result

Theorem 4.1

For $M\in \mathbb{Z}_{\geq 0}$, let ${\bm{l}}_0, \dots, {\bm{l}}_M\in\mathbb{R}^K$ be ${M{+}1}$ perturbation labels. Let ${\bm{f}},\widetilde{{\bm{f}}}:\mathbb{R}^{{d_Z}}\to\mathbb{R}^{d_X}$, ${\bm{W}},\widetilde{{\bm{W}}}\in\mathbb{R}^{{d_Z}\times K}$, and $\mathbb{P},\widetilde{\mathbb{P}}$ be distr… Assume further that: … Then the latent representation and the effects of the observed perturbation c…

Figures (5)

  • Figure 1: Task Description and Example Following the Assumed Data Generating Process. (a) During training, we are given $M\!=\!5$ training data sets in observation space (right, contour plots in grey), each of which is generated under a known combination of $K\!=\!3$ elementary perturbations. The corresponding (training) perturbation labels ${\bm{l}}_e$ are shown below the plots. During testing, we are given a new perturbation label and the task is to predict the corresponding distribution in observation space (right, blue and orange). We tackle this task by assuming that the effect of perturbations is linear additive in a suitable latent space (left). Both plots show kernel density estimates of the distributions. (b) For each experiment or environment $e\in[M]_0$, the corresponding dataset comprises a perturbation label $\bm l_e$ and $N_e$ observations $\bm{x}_{e, i}$. Perturbations are assumed to act as mean-shifts on a latent basal state, ${\bm{z}}^\text{pert}_{e,i}={\bm{z}}^\text{base}_{e,i}+{\bm{W}}{\bm{l}}_e$ for an unknown perturbation matrix ${\bm{W}}$. An (unknown) stochastic nonlinear decoder with noise $\bm\varepsilon_{e,i}$ then yields the observed ${\bm{x}}_{e, i}={\bm{f}}({\bm{z}}^\mathrm{pert}_{e,i},\bm\varepsilon_{e,i})$. In the example in (a), ${\bm{f}}$ is constant in its second argument. Shaded and white nodes indicate observed and unobserved/latent variables, respectively.
  • Figure 2: Overview of the Perturbation Distribution Autoencoder (PDAE). In a PDAE, the distribution of a target perturbation condition $h$ (purple) is simulated by encoding, perturbing, and (stochastically) decoding data from a source condition $e$ (blue). Dashed arrows indicate model inputs and green boxes model components with trainable parameters. The learning objective (orange, dotted arrows) consists of: a perturbation loss (bottom), measuring, for all pairs of training domains $(e,h)$, the dissimilarity between the true distribution of ${\bm{X}}_h$ and its simulated version based on domain $e$; and a distributional reconstruction loss (top), measuring, for each source domain $e$, the dissimilarity between the true conditional distribution of source observations ${\bm{X}}_e$ that are mapped to the same encoding $\widehat{{\bm{Z}}}_e$ and the corresponding distribution induced by the stochastic decoder. At test time, the target perturbation label ${\bm{l}}_h$ can be replaced with an unseen ${\bm{l}}_\mathrm{test}$ to make predictions.
  • Figure 3: Qualitative Comparison of PDAE and CPA on Synthetic 2D Data. Rows correspond to latent space (top) and observation space (bottom). Columns show the ground truth data (left), PDAE predictions (center), and CPA predictions (right). Training domains are shown in grey, an in-distribution (ID) test case with ${\bm{l}}^\textsc{id}_\mathrm{test}\!=\!(1,1,0)^\top$ (which overlaps with one of the training domains) in blue, and an out-of-distribution (OOD) test case with ${\bm{l}}_\mathrm{test}^\textsc{ood}\!=\!(1,0,1)^\top$ in orange. All plots show kernel density estimates of the distributions; crosses (x) indicate the corresponding means. As can be seen, PDAE recovers an affine transformation of the true latents (top, center), leading to accurate distributional predictions for the training and ID test domains (bottom, center). However, the OOD test domain is mapped to a part of the latent space not seen during training (top, center). As a result, the corresponding decoder output does not accurately match the true OOD distribution (bottom, left vs. center). CPA accurately predicts the means of the training distributions (bottom right, black crosses) but does not recover the true latents up to an affine transformation (top right). Moreover, the distributions predicted by CPA do not match the ground truth well, especially for the test conditions (bottom, left vs. right). Recall, however, that CPA, unlike PDAE, is not trained for distributional reconstruction; see the CPA section of the paper for details.
  • Figure 4: Complete Ground Truth Training and Test Data. Shown are all train and test distributions, with crosses marking the respective means. The training cases are depicted in grey, the ID test cases in blue and the OOD test cases in orange.
  • Figure 5: Evaluation of Robustness Under Increasing Levels of Noise. We compare the ID test performance of the baseline methods, CPA, and PDAE for higher dimensional observations with additional noise dimensions appended while increasing the level of noise $\sigma_\varepsilon$.
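The encode, perturb, decode pipeline and the energy-score comparison described in the Figure 2 caption can be sketched as follows. This is a rough illustration, not the PDAE implementation: `simulate_target` and `energy_score_loss` are hypothetical names, the encoder and decoder are arbitrary callables, and shifting the encoding by $\widehat{{\bm{W}}}({\bm{l}}_h - {\bm{l}}_e)$ is an assumption about how the perturbation module moves data from source condition $e$ to target condition $h$. The loss is the standard sample estimate of the negatively oriented energy score, $\mathbb{E}\|X - Y\| - \tfrac{1}{2}\mathbb{E}\|X - X'\|$, with $X, X'$ drawn from the prediction and $Y$ from the target.

```python
import numpy as np

def energy_score_loss(pred, target):
    """Sample estimate of E||X - Y|| - 0.5 * E||X - X'|| (lower is better)."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    n = pred.shape[0]
    # Mean distance between predicted and target samples.
    cross = np.linalg.norm(pred[:, None, :] - target[None, :, :], axis=-1).mean()
    # Mean distance between distinct predicted samples (zero diagonal excluded).
    within = np.linalg.norm(pred[:, None, :] - pred[None, :, :], axis=-1)
    within_mean = within.sum() / (n * (n - 1))
    return cross - 0.5 * within_mean

def simulate_target(x_source, l_source, l_target, encoder, W_hat, decoder, rng):
    """Simulate samples of target condition h from source samples of condition e."""
    z_hat = encoder(x_source)                        # encode source observations
    z_shift = z_hat + W_hat @ (l_target - l_source)  # assumed relative latent shift
    eps = rng.standard_normal(z_shift.shape)         # noise for the stochastic decoder
    return decoder(z_shift, eps)

# Usage with hypothetical linear stand-ins for the trained encoder and decoder.
rng = np.random.default_rng(0)
W_hat = np.eye(2)
l_e, l_h = np.array([0.0, 0.0]), np.array([2.0, 1.0])
x_e = rng.standard_normal((300, 2))
x_h = rng.standard_normal((300, 2)) + W_hat @ l_h
x_sim = simulate_target(x_e, l_e, l_h, lambda x: x, W_hat, lambda z, eps: z, rng)
loss = energy_score_loss(x_sim, x_h)
```

Averaging such a loss over all pairs of training domains $(e,h)$ would give a perturbation loss of the kind the caption describes; at test time, `l_target` would be replaced by an unseen label.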

Theorems & Definitions (17)

  • Example 2.1: Gene perturbations
  • Example 2.2: Drug perturbations
  • Theorem 4.1: Affine identifiability for Gaussian latents
  • Corollary 4.1: Affine recovery of the perturbation matrix
  • Remark 4.2: Sufficient diversity
  • Remark 4.3: Choice of reference
  • Remark 4.4: Deterministic vs noisy mixing
  • Theorem 4.5: Extrapolation to span of relative perturbations
  • Remark 4.6: Additive vs linear perturbations
  • Corollary 5.1
  • ...and 7 more
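The title of Theorem 4.5 suggests that extrapolation guarantees apply to test labels lying in the span of relative training perturbations. The full statement is not reproduced here, so the following check rests on an assumed reading: a test label ${\bm{l}}_\mathrm{test}$ qualifies if ${\bm{l}}_\mathrm{test} - {\bm{l}}_0$ lies in the span of the differences ${\bm{l}}_e - {\bm{l}}_0$ for a chosen reference label ${\bm{l}}_0$. The function name `in_relative_span` and the example labels are hypothetical.

```python
import numpy as np

def in_relative_span(train_labels, test_label, reference=0, tol=1e-8):
    """Check whether test_label - l_ref lies in span{l_e - l_ref} (assumed criterion)."""
    labels = np.asarray(train_labels, dtype=float)
    ref = labels[reference]
    # Rows are the relative perturbations l_e - l_ref for all non-reference labels.
    diffs = np.delete(labels, reference, axis=0) - ref
    target = np.asarray(test_label, dtype=float) - ref
    if diffs.shape[0] == 0:
        # With only the reference label, the span contains just the zero vector.
        return bool(np.linalg.norm(target) <= tol)
    # Least-squares projection of the target onto the span of the differences.
    coeffs, *_ = np.linalg.lstsq(diffs.T, target, rcond=None)
    residual = np.linalg.norm(diffs.T @ coeffs - target)
    return bool(residual <= tol)

# Hypothetical training labels with K = 3 elementary perturbations.
train = [[0, 0, 0], [1, 0, 0], [0, 1, 0]]
covered = in_relative_span(train, [1, 1, 0])      # combination of seen perturbations
not_covered = in_relative_span(train, [0, 0, 1])  # third perturbation never observed
```

Membership in the span concerns identifiability of the latent shift only. As the Figure 3 caption notes, a test label can still map to a region of latent space not seen during training, where the decoder itself must extrapolate.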