
Stable Differentiable Causal Discovery

Achille Nazaret, Justin Hong, Elham Azizi, David Blei

TL;DR

The paper tackles learning causal graphs, represented as directed acyclic graphs (DAGs), from observational and interventional data. It identifies numerical instability in existing differentiable causal discovery (DCD) methods and introduces Stable Differentiable Causal Discovery (SDCD), which combines a stable spectral acyclicity constraint $h_\rho(A)$ with a two-stage edge-pruning training procedure. The authors prove stability and correctness, and extensive experiments show that SDCD converges faster, is more accurate, and scales to thousands of variables on both observational and interventional data; code is available at https://github.com/azizilab/sdcd. The approach enables reliable large-scale causal discovery in data-rich settings and extends the applicability of DCD to real-world problems.

Abstract

Inferring causal relationships as directed acyclic graphs (DAGs) is an important but challenging problem. Differentiable Causal Discovery (DCD) is a promising approach to this problem, framing the search as a continuous optimization. But existing DCD methods are numerically unstable, with poor performance beyond tens of variables. In this paper, we propose Stable Differentiable Causal Discovery (SDCD), a new method that improves on previous DCD methods in two ways: (1) It employs an alternative constraint for acyclicity; this constraint is more stable, both theoretically and empirically, and fast to compute. (2) It uses a training procedure tailored for sparse causal graphs, which are common in real-world scenarios. We first derive SDCD and prove its stability and correctness. We then evaluate it with both observational and interventional data and in both small-scale and large-scale settings. We find that SDCD outperforms existing methods in both convergence speed and accuracy and can scale to thousands of variables. We provide code at https://github.com/azizilab/sdcd.

Paper Structure

This paper contains 38 sections, 12 theorems, 38 equations, 14 figures, 8 tables, 2 algorithms.

Key Result

Theorem 3.2

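For any sequence $(a_k)_{k\in \mathbb{N}^*} \in \mathbb{R}_{\geq 0}^{\mathbb{N}^*}$, if $a_k > 0$ for all $k \in \llbracket 1,d \rrbracket$, then for any matrix $A \in \mathbb{R}_{\geq0}^{d\times d}$ we have $h_a(A) = 0 \iff A$ is the adjacency matrix of a DAG, where $h_a(A) = \sum_{k=1}^{\infty} a_k \operatorname{Tr}(A^k)$ is the power series trace of Definition 3.1. We say that $h_a$ is a PST constraint.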
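As a sanity check on Theorem 3.2, the sketch below evaluates a truncated power series trace on a small DAG and on the same graph with one back edge added. It assumes the series form $h_a(A) = \sum_k a_k \operatorname{Tr}(A^k)$, truncates the series at $k = d$ (enough for the zero test, since any closed walk contains a simple cycle of length at most $d$), and picks $a_k = 1/k!$ purely for illustration. The helper `pst_constraint` is hypothetical and is not taken from the SDCD code.

```python
import math

import numpy as np


def pst_constraint(A: np.ndarray, coeffs) -> float:
    """Truncated power series trace: sum_k coeffs[k-1] * Tr(A^k).

    Hypothetical helper for illustration. `A` must be square and
    non-negative; `coeffs` holds a_1, ..., a_K (the series stops at K).
    """
    d = A.shape[0]
    power = np.eye(d)
    total = 0.0
    for a_k in coeffs:
        power = power @ A  # power now holds A^k
        total += a_k * np.trace(power)
    return float(total)


d = 4
# Illustrative choice a_k = 1/k!; Theorem 3.2 only needs a_k > 0 for k <= d.
coeffs = [1.0 / math.factorial(k) for k in range(1, d + 1)]

# Chain 0 -> 1 -> 2 -> 3: acyclic, so every Tr(A^k) is zero.
dag = np.zeros((d, d))
dag[0, 1] = dag[1, 2] = dag[2, 3] = 1.0

# Adding the back edge 3 -> 0 closes a cycle of length 4.
cyclic = dag.copy()
cyclic[3, 0] = 1.0

print(pst_constraint(dag, coeffs))     # 0.0
print(pst_constraint(cyclic, coeffs))  # strictly positive
```

Because $A$ is non-negative, every term $a_k \operatorname{Tr}(A^k)$ is non-negative, so the sum is zero only if no cycle of length $k \le d$ exists.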

Figures (14)

  • Figure 1: Visual representation of the SDCD method.
  • Figure 2: Constraint behaviors when evaluated on uniform random matrices in $[0, \varepsilon]^{d \times d}$ (dashed) or on a cycle of length $d/2$ with weight $\varepsilon$ (solid). The y-axis shows the constraint's value; the x-axis is the weight scale $\varepsilon$ (left) or the number of variables $d$ (right). Only the proposed $h_\rho$ (orange) remains stable; the other constraints vanish to zero exponentially or escalate to infinity (as soon as $d>10$). Vertical dotted lines mark where a constraint leaves its domain of definition. All these failures were encountered during DCD experiments.
  • Figure 3: Structural Hamming distance (SHD) across simulations on observational data with an increasing number of variables $d$. SDCD achieves the best SHDs and is the only method that scales above 200 variables with a nontrivial SHD. Missing data points mean the method failed to run. Error bars show the standard deviation over 30 random datasets for $d \leq 50$ and five for $d>50$ (175 total). Lower is better.
  • Figure 4: SHD across simulations with an increasing proportion of intervened variables, varying the total number of variables $d$ (columns) and the average number of edges per variable $s$ (rows). SDCD is the only method that consistently improves with interventional data, and it has the best SHDs for sparse graphs (edge density $\delta \leq 45\%$). Each boxplot summarizes 5 random datasets (45 datasets total).
  • Figure 5: The effect of constraints on the learned graph throughout training. Training with the penalty $h_\rho+$ (dashed purple; exactly $h_\rho$ with a hard mask on the diagonal to prevent self-loops, as implemented in SDCD) converges fastest toward a DAG. Left: training with $h$ as a regularization penalty. Right: training with $h$ as an augmented Lagrangian constraint. "Threshold to DAG" is the smallest $\eta$ at which all edges with weight $> \eta$ form a DAG.
  • ...and 9 more figures
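The "threshold to DAG" metric from the Figure 5 caption can be computed directly. The sketch below assumes a non-negative weighted adjacency matrix `W` (entry `W[i, j]` is the strength of edge i → j) and tests acyclicity with Kahn's topological sort; `is_dag` and `threshold_to_dag` are hypothetical helper names, not functions from the SDCD repository.

```python
import numpy as np


def is_dag(adj: np.ndarray) -> bool:
    """Kahn's algorithm: True if the boolean adjacency matrix has no cycle.

    Self-loops count as cycles: a node with a self-loop never reaches
    in-degree zero, so it is never visited.
    """
    indegree = adj.sum(axis=0).astype(int)  # incoming edges per node (column sums)
    queue = [j for j in range(adj.shape[0]) if indegree[j] == 0]
    visited = 0
    while queue:
        i = queue.pop()
        visited += 1
        for j in np.flatnonzero(adj[i]):  # children of node i
            indegree[j] -= 1
            if indegree[j] == 0:
                queue.append(j)
    return visited == adj.shape[0]


def threshold_to_dag(W: np.ndarray) -> float:
    """Smallest eta such that the edges with weight > eta form a DAG (assumes W >= 0)."""
    # The answer is either 0 or one of the edge weights; scan them in ascending order.
    candidates = np.concatenate(([0.0], np.unique(W[W > 0])))
    for eta in candidates:
        if is_dag(W > eta):
            return float(eta)
    # Not reached: at eta = max weight no edge survives, and an empty graph is a DAG.
    return float(candidates[-1])


# Chain 0 -> 1 -> 2 with strong weights, closed into a cycle by a weak edge 2 -> 0.
W = np.array([
    [0.0, 0.9, 0.0],
    [0.0, 0.0, 0.8],
    [0.1, 0.0, 0.0],
])
print(threshold_to_dag(W))  # 0.1
```

Raising $\eta$ only removes edges, so the first acyclic graph found in the ascending scan gives the smallest threshold.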

Theorems & Definitions (25)

  • Definition 3.1: The Power Series Trace Family
  • Theorem 3.2: PST constraint
  • Definition 3.3
  • Theorem 3.4: PST instability
  • Definition 3.5: Spectral radius
  • Theorem 3.6 (Cvetković et al., 1980; Lee et al., 2019)
  • Theorem 3.7
  • Remark 3.8
  • Remark 4.1
  • Theorem 4.2
  • ...and 15 more
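Figure 2's solid-line experiment can be reproduced in miniature. For a non-negative matrix, the spectral radius $\rho(A)$ is zero exactly when $A$ has no cycle, and a single cycle whose edges all have weight $\varepsilon$ has $\rho(A) = \varepsilon$ whatever its length. The exponential-trace constraint $\operatorname{tr}(e^A) - d$ (a PST constraint with $a_k = 1/k!$, used by NOTEARS) instead scales like $m\,\varepsilon^m/m!$ for a cycle of length $m$. This sketch computes $h_\rho$ with a dense eigendecomposition for clarity, which is an assumption and not necessarily how SDCD computes it; all function names are hypothetical.

```python
import math

import numpy as np


def cycle_matrix(d: int, m: int, eps: float) -> np.ndarray:
    """d x d matrix holding one cycle 0 -> 1 -> ... -> m-1 -> 0 with edge weight eps."""
    A = np.zeros((d, d))
    for i in range(m):
        A[i, (i + 1) % m] = eps
    return A


def spectral_radius(A: np.ndarray) -> float:
    """h_rho(A): largest eigenvalue modulus, via a dense O(d^3) eigendecomposition."""
    return float(np.max(np.abs(np.linalg.eigvals(A))))


def exp_trace(A: np.ndarray) -> float:
    """tr(e^A) - d, computed as sum_{k>=1} Tr(A^k) / k! truncated at k = 2d."""
    power = np.eye(A.shape[0])
    total = 0.0
    for k in range(1, 2 * A.shape[0] + 1):
        power = power @ A  # power now holds A^k
        total += float(np.trace(power)) / math.factorial(k)
    return total


for d in (10, 20, 40):
    A = cycle_matrix(d, d // 2, 0.5)  # cycle of length d/2 with weight 0.5
    # h_rho stays at 0.5 (up to rounding) while exp_trace collapses toward 0.
    print(d, spectral_radius(A), exp_trace(A))
```

Both values are positive for every $d$, so both constraints still detect the cycle in exact arithmetic; the difference is that the exponential-trace signal shrinks factorially with the cycle length, which is the vanishing behavior Figure 2 reports.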