
Causality Pursuit from Heterogeneous Environments via Neural Adversarial Invariance Learning

Yihong Gu, Cong Fang, Peter Bühlmann, Jianqing Fan

TL;DR

It is shown that FAIR-NN can find the invariant variables and quasi-causal variables under a minimal identification condition and that the resulting procedure is adaptive to low-dimensional composition structures in a non-asymptotic analysis.

Abstract

Pursuing causality from data is a fundamental problem in scientific discovery, treatment intervention, and transfer learning. This paper introduces a novel algorithmic method for addressing nonparametric invariance and causality learning in regression models across multiple environments, where the joint distribution of response variables and covariates varies, but the conditional expectations of outcome given an unknown set of quasi-causal variables are invariant. The challenge of finding such an unknown set of quasi-causal or invariant variables is compounded by the presence of endogenous variables that have heterogeneous effects across different environments. The proposed Focused Adversarial Invariant Regularization (FAIR) framework utilizes an innovative minimax optimization approach that drives regression models toward prediction-invariant solutions through adversarial testing. Leveraging the representation power of neural networks, FAIR neural networks (FAIR-NN) are introduced for causality pursuit. It is shown that FAIR-NN can find the invariant variables and quasi-causal variables under a minimal identification condition and that the resulting procedure is adaptive to low-dimensional composition structures in a non-asymptotic analysis. Under a structural causal model, variables identified by FAIR-NN represent pragmatic causality and provably align with exact causal mechanisms under conditions of sufficient heterogeneity. Computationally, FAIR-NN employs a novel Gumbel approximation with decreasing temperature and a stochastic gradient descent ascent algorithm. The procedures are demonstrated using simulated and real-data examples.
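To make the idea of a focused adversarial invariance test concrete, here is a minimal brute-force sketch in Python with NumPy. It is a hypothetical linear analogue, not FAIR-NN or the paper's estimator: for each candidate set $S$ it fits pooled least squares on $X_S$, lets a linear test function of $X_S$ alone (the "focused" adversary) probe each environment's residuals, uses the closed-form value $\tfrac12 c_e^\top \Sigma_e^{-1} c_e$ of that adversary's supremum, and selects the set minimizing pooled loss plus $\gamma$ times the summed penalty. The exact objective, the helper names focused_penalty and select_subset, and the simulated environments are assumptions made for illustration.

```python
import itertools
import numpy as np

def focused_penalty(X_list, y_list, cols, beta):
    """Sum over environments of the sup of a linear adversary that sees only X_S.

    Hypothetical form: for environment e,
        sup_theta mean[r * (X_S theta) - (X_S theta)**2 / 2] = c_e' Sigma_e^{-1} c_e / 2,
    with r = y - X_S beta, c_e = mean[X_S * r] and Sigma_e = mean[X_S X_S'].
    """
    if not cols:
        return 0.0  # in this sketch the empty set admits no test functions
    total = 0.0
    for X, y in zip(X_list, y_list):
        XS = X[:, cols]
        r = y - XS @ beta
        c = XS.T @ r / len(y)
        Sigma = XS.T @ XS / len(y)
        total += 0.5 * c @ np.linalg.solve(Sigma, c)
    return total

def select_subset(X_list, y_list, gamma):
    """Brute-force search over all subsets of the d covariates (feasible only for small d)."""
    d = X_list[0].shape[1]
    X_pool, y_pool = np.vstack(X_list), np.concatenate(y_list)
    best_obj, best_S = np.inf, None
    for k in range(d + 1):
        for S in itertools.combinations(range(d), k):
            cols = list(S)
            if cols:
                beta = np.linalg.lstsq(X_pool[:, cols], y_pool, rcond=None)[0]
                resid = y_pool - X_pool[:, cols] @ beta
            else:
                beta, resid = np.zeros(0), y_pool
            obj = 0.5 * np.mean(resid ** 2) + gamma * focused_penalty(X_list, y_list, cols, beta)
            if obj < best_obj:
                best_obj, best_S = obj, S
    return best_S

rng = np.random.default_rng(0)
X_list, y_list = [], []
for s in (0.5, 1.0, 2.0):  # hypothetical environment-specific noise scales for the child X2
    x1 = rng.normal(size=2000)
    y = x1 + rng.normal(size=2000)      # invariant mechanism: E[Y | X1] = X1 in every environment
    x2 = y + s * rng.normal(size=2000)  # child of Y whose mechanism shifts across environments
    X_list.append(np.column_stack([x1, x2]))
    y_list.append(y)

print(select_subset(X_list, y_list, gamma=0.0))   # pooled least-squares criterion only
print(select_subset(X_list, y_list, gamma=20.0))  # with the invariance penalty
```

In this simulated setting, the pooled criterion alone ($\gamma=0$) prefers both covariates because the child $X_2$ is predictive, while a large $\gamma$ should favor the invariant set $\{X_1\}$, since the relation between $Y$ and $X_2$ shifts with the environment-specific noise scale. Enumerating all $2^d$ subsets is only practical for small $d$, which is one motivation for continuous relaxations such as the Gumbel approximation mentioned above.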

Paper Structure

This paper contains 86 sections, 35 theorems, 459 equations, 9 figures, 2 tables, and 2 algorithms.

Key Result

Theorem 2.1

Assume Conditions cond:regularity-fairnn and cond-fairnn-ident hold. Then $\gamma^\star_{\mathtt{NN}}=\sup_{S\subseteq [d]:\, \mathsf{b}_{\mathtt{NN}}(S)>0} \mathsf{b}_{\mathtt{NN}}(S)/\bar{\mathsf{d}}_{\mathtt{NN}}(S) < \infty$. Consider the estimator that solves Eq. (eq:def-fair-lse) with $\gamma \ge 8\gamma^\star_{\mathtt{NN}}$ and the function classes in Eq. (eq:fairnn-function-class), with $L, N$ satisfying $NL \le \ldots$
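The abstract describes FAIR as a minimax optimization and states that FAIR-NN is computed with stochastic gradient descent ascent (SGDA). The following is a minimal SGDA sketch on a hypothetical linear analogue with a fixed selected set: the regression coefficients $\beta$ take minibatch descent steps while each environment's linear test function $g_e(x) = x^\top\theta_e$ takes ascent steps. The objective, step sizes, batch size, and the function name sgda_fixed_subset are illustrative assumptions; this is not the estimator of Eq. (eq:def-fair-lse), and $\gamma$ here is simply a user-chosen constant rather than a value calibrated against $\gamma^\star_{\mathtt{NN}}$.

```python
import numpy as np

def sgda_fixed_subset(X_list, y_list, gamma, steps=4000, batch=128,
                      lr_descent=0.01, lr_ascent=0.05, seed=0):
    """Stochastic gradient descent ascent on a hypothetical linear objective

        min_beta max_{theta_1..theta_E}  sum_e mean_e[ r**2 / 2 + gamma * (r * g_e - g_e**2 / 2) ],

    where r = y - X @ beta and g_e = X @ theta_e is environment e's linear test function.
    All columns of X are treated as the (fixed) selected set.
    """
    rng = np.random.default_rng(seed)
    d = X_list[0].shape[1]
    beta = np.zeros(d)
    thetas = [np.zeros(d) for _ in X_list]
    for _ in range(steps):
        grad_beta = np.zeros(d)
        new_thetas = []
        for X, y, theta in zip(X_list, y_list, thetas):
            idx = rng.integers(0, len(y), size=batch)  # minibatch from this environment
            Xb, yb = X[idx], y[idx]
            r = yb - Xb @ beta
            g = Xb @ theta
            # gradient in beta of mean[r**2 / 2 + gamma * r * g]
            grad_beta += -Xb.T @ (r + gamma * g) / batch
            # ascent step for theta_e along the gradient of gamma * mean[r * g - g**2 / 2]
            new_thetas.append(theta + lr_ascent * gamma * Xb.T @ (r - g) / batch)
        beta = beta - lr_descent * grad_beta  # simultaneous update using the old thetas
        thetas = new_thetas
    return beta, thetas

rng = np.random.default_rng(1)
X_list, y_list = [], []
for scale in (0.5, 1.0, 1.5):  # hypothetical environment-specific spread of X1
    x1 = scale * rng.normal(size=2000)
    X_list.append(x1[:, None])
    y_list.append(x1 + rng.normal(size=2000))  # invariant mechanism E[Y | X1] = X1

beta, thetas = sgda_fixed_subset(X_list, y_list, gamma=2.0)
print(beta)  # should be near [1.0]
```

This toy objective is convex in $\beta$ and strongly concave in each $\theta_e$, so simultaneous small steps behave well; the adversary uses a larger step size, a common two-timescale heuristic. With neural networks in place of the linear model, such guarantees do not carry over automatically.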

Figures (9)

  • Figure 1: The running example with $d=3$ and $|\mathcal{E}|=3$. An arrow from node $x$ to node $y$ indicates that $x$ directly affects $y$. The data-generating process of $(X_1,\ldots, X_3, Y)$ in each environment is given by the set of assignments in the corresponding panel, where $\varepsilon_1,\ldots,\varepsilon_4$ are independent noise variables. Relative to the first environment $e=1$, the assignment for $X_3$ is perturbed in $e=2$ and the assignment for $X_2$ is perturbed in $e=3$; the perturbed assignments are marked in red.
  • Figure 2: (a) An illustration of the two-environment model, in which the SCMs of the two environments share the same associated graph: $M^{(0)}$ is an observational environment and $M^{(1)}$ is an interventional environment in which an unknown intervention is applied to $(X_4, X_6, X_7)$; $M^{(0)}$ and $M^{(1)}$ are defined as in Eq. (eq:scm-model). (b) The associated graph $\widetilde{G}$ of $\widetilde{M}$, constructed from $(M^{(0)}, M^{(1)})$ via Eq. (eq:model-2scm); this is an alternative depiction of the environments in (a). (c) An illustration of Proposition prop:ident-transfer-learning, showing how $S_\star$ changes as more and more environments are observed: an arrow from $E$ to $X_j$ drawn in the color of environment $e \in \{1, 2, 3, 4\}$ (blue, yellow, green, and red, respectively) means that $X_j$ is intervened on in $e$. For example, $0 \leftrightarrow 3$ (green) means that, with interventions in environments 1, 2, and 3, the invariant variable set is $\{1,2,3,7\}$ (green). Although $X_7$ is reverse causal and hence related to $Y$, this cannot be detected from the given environments alone.
  • Figure 3: Visualization of (a) the SCM and (b) $\mathrm{sig}(w)$ during training in one trial of the FAIR-Linear estimator. Colors indicate each variable's relationship with $Y$: blue = parent, red = child, orange = offspring, light blue = other.
  • Figure 4: Simulation results for linear models with (a) $d=70$ and (b) $d=15$. Both panels show how the median estimation error (over $50$ replications, on a log scale) of each estimator (marked by a distinct shape and color) changes as $n$ varies over (a) $\{200, 500, 1000, 2000, 5000\}$ and (b) $\{100, 200, 500, 800, 1000\}$, respectively.
  • Figure 5: Discovery in real physical systems. (a) The unified cause-effect relationships and interventions, analogous to Figure fig:scm-ident (b). (b) The average out-of-sample (OOS) $R^2$ of different estimators, shown as a spider chart: the axis labeled by the placeholder variable $Z$ corresponds to the test environment in which $Z$ is strongly intervened on. The performance of Oracle-NN and FAIR-NN-RF is almost identical. (c) The average (over 100 replications) of the worst-case (across 5 environments) OOS $R^2$ of different methods as a function of $n$. (d) The variable selection rate over $100$ trials for different methods (top panel) and for FAIR-NN at various $n$ (bottom panel). Colors indicate each variable's relationship with $Y$: blue = parent, red = child, orange = neither ancestor nor descendant. (e) The distribution of worst-case OOS $R^2$ (y-axis) for the Gumbel-trick-optimized FAIR-NN (Gumbel), the subsequently refitted estimator (Refit), and pooled LS (Pooled) when FAIR-NN selects the wrong variables. The subplots, from top to bottom, cover (i) failure of selection consistency; (ii) false positives, where the child $X_8=\widetilde{V}_3$ is falsely selected; and (iii) false negatives, where not all of the ground-truth variables $(X_1,\ldots, X_5) = (R, G, B, \theta_1, \theta_2)$ are selected.
  • ...and 4 more figures
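The abstract mentions a Gumbel approximation with decreasing temperature, and Figure 3 tracks $\mathrm{sig}(w)$ during training, which suggests selection gates parameterized by logits $w$. As a minimal sketch under that assumption, the NumPy code below draws relaxed binary gates with the Gumbel-sigmoid (binary concrete) trick and anneals the temperature geometrically. The gate form, the schedule and its constants, and the helper names temperature and gumbel_sigmoid_gates are hypothetical and not taken from the paper.

```python
import numpy as np

def temperature(step, tau0=1.0, decay=0.995, tau_min=0.05):
    """Hypothetical geometric annealing schedule, floored at tau_min."""
    return max(tau_min, tau0 * decay ** step)

def gumbel_sigmoid_gates(w, tau, rng):
    """Relaxed binary gates in [0, 1] for gate logits w (Gumbel-sigmoid / binary concrete).

    log(u) - log(1 - u) with u ~ Uniform(0, 1) is standard logistic noise, which has the
    same law as the difference of two independent Gumbel(0, 1) draws. As tau -> 0 the gate
    approaches the indicator 1{w + noise > 0}, which equals 1 with probability sigmoid(w).
    """
    u = rng.uniform(1e-12, 1.0 - 1e-12, size=np.shape(w))
    logits = (np.asarray(w) + np.log(u) - np.log1p(-u)) / tau
    return 0.5 * (1.0 + np.tanh(0.5 * logits))  # overflow-free form of sigmoid(logits)

rng = np.random.default_rng(0)
w = np.array([3.0, 0.0, -3.0])   # hypothetical gate logits for three covariates
tau = temperature(2000)          # late in training the schedule sits at tau_min
gates = gumbel_sigmoid_gates(np.broadcast_to(w, (20000, 3)), tau, rng)
print(gates.mean(axis=0))        # empirical gate means, close to sigmoid(w)
```

As the temperature decreases, individual gates concentrate near 0 or 1 while their averages stay close to $\mathrm{sig}(w)$, so gradients with respect to $w$ remain available through the relaxation. In a training loop, such gates would typically multiply the covariates before they enter the model, with the temperature lowered as iterations proceed; this usage is likewise an assumption of the sketch.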

Theorems & Definitions (54)

  • Definition 1: Deep ReLU network class
  • Remark 2.1: Minimal Heterogeneity Condition for Identification
  • Theorem 2.1: Oracle-type Inequality for FAIR-NN Least Squares Estimator
  • Remark 2.2: Interpretation of $\mathsf{b}_{\mathtt{NN}}(S)$ and $\bar{\mathsf{d}}_{\mathtt{NN}}(S)$
  • Remark 2.3: Identification
  • Definition 2: $(\beta,C)$-smooth Function
  • Definition 3: Hierarchical Composition Model $\mathcal{H}_{\mathtt{HCM}}(d, l, \mathcal{O}, C)$
  • Corollary 2.2: Convergence Rate for FAIR-NN
  • Remark 2.4: Error Guarantees for All $n$
  • Remark 2.5: Choice of the Hyper-parameter $\gamma$
  • ...and 44 more