Contextual Directed Acyclic Graphs

Ryan Thompson, Edwin V. Bonilla, Robert Kohn

TL;DR

This work addresses learning DAGs whose structure varies with context, rather than a single graph shared by the whole population. It introduces a neural network that maps contextual features $z$ to a dense graph $\tilde{W}$. A projection layer then enforces acyclicity through the DAGMA constraint $h_s(W)=0$ and sparsity through an $\ell_1$ ball, which yields a context-specific DAG $W(z)$. The authors derive an analytic gradient through the projection layer and prove convergence of the log-det projection, which enables scalable, GPU-accelerated training. Experiments on synthetic data show that the contextual DAG recovers context-dependent graphs where fixed-DAG baselines fail. An application to a drug-consumption dataset illustrates how the learned graph changes with context. The authors also release an open-source Julia implementation, ContextualDAG.
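The DAGMA constraint mentioned above is built on the log-det acyclicity function $h_s(W) = -\log\det(sI - W\circ W) + p\log s$. This function equals zero exactly when $W$ is acyclic, provided $W$ lies in the domain where $sI - W\circ W$ is an M-matrix. The minimal NumPy sketch below evaluates that function. The name `h_dagma` and the default $s=1$ are illustrative choices and are not taken from the paper or from ContextualDAG.

```python
import numpy as np

def h_dagma(W: np.ndarray, s: float = 1.0) -> float:
    """Log-det acyclicity function h_s(W) = -log det(sI - W∘W) + p log s.

    Returns 0 for acyclic W and a positive value for cyclic W inside the
    M-matrix domain (spectral radius of W∘W below s).
    """
    p = W.shape[0]
    M = s * np.eye(p) - W * W  # W * W is the elementwise (Hadamard) square
    sign, logdet = np.linalg.slogdet(M)
    # A positive determinant is necessary (not sufficient) for the M-matrix domain.
    if sign <= 0:
        raise ValueError("sI - W∘W has non-positive determinant; W is outside the domain")
    return float(-logdet + p * np.log(s))

# Acyclic example (edges 0->1, 0->2, 1->2) and a two-node cycle.
W_dag = np.array([[0.0, 0.8, -0.3],
                  [0.0, 0.0,  1.2],
                  [0.0, 0.0,  0.0]])
W_cyc = np.array([[0.0, 0.5],
                  [0.5, 0.0]])
print(h_dagma(W_dag), h_dagma(W_cyc))
```

For the DAG, $W\circ W$ is nilpotent, so $\det(sI - W\circ W) = s^p$ and the two terms cancel. The cycle contributes a strictly positive value.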

Abstract

Estimating the structure of directed acyclic graphs (DAGs) from observational data remains a significant challenge in machine learning. Most research in this area concentrates on learning a single DAG for the entire population. This paper considers an alternative setting where the graph structure varies across individuals based on available "contextual" features. We tackle this contextual DAG problem via a neural network that maps the contextual features to a DAG, represented as a weighted adjacency matrix. The neural network is equipped with a novel projection layer that ensures the output matrices are sparse and satisfy a recently developed characterization of acyclicity. We devise a scalable computational framework for learning contextual DAGs and provide a convergence guarantee and an analytical gradient for backpropagating through the projection layer. Our experiments suggest that the new approach can recover the true context-specific graph where existing approaches fail.
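One way to picture the network described in the abstract is as a standard multilayer perceptron whose last layer has $p^2$ outputs, reshaped into $\tilde{W}$. The NumPy sketch below shows only that forward pass and rests on several assumptions of our own: the ReLU hidden layers, the hidden widths, the initialization, and zeroing the diagonal to obtain a simple directed graph (no self-loops). The projection layer is not included, so the output is the dense $\tilde{W}$ rather than a DAG. `init_mlp` and `dense_graph` are hypothetical names.

```python
import numpy as np

def init_mlp(m, p, hidden=(32, 32), seed=0):
    """Random weights for an MLP mapping R^m -> R^(p*p) (hypothetical sizes)."""
    rng = np.random.default_rng(seed)
    sizes = [m, *hidden, p * p]
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        W = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out))
        b = np.zeros(d_out)
        params.append((W, b))
    return params

def dense_graph(params, z, p):
    """Map a batch of contexts z with shape (n, m) to dense graphs with shape (n, p, p)."""
    h = z
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)  # ReLU on hidden layers only
    W_tilde = h.reshape(-1, p, p)
    # Zero the diagonal so the output has no self-loops (a simple directed graph).
    return W_tilde * (1.0 - np.eye(p))

# Usage: m = 2 contextual features, p = 4 nodes, a batch of 5 contexts.
params = init_mlp(m=2, p=4, seed=1)
z = np.random.default_rng(2).normal(size=(5, 2))
W_tilde = dense_graph(params, z, p=4)
print(W_tilde.shape)
```

Each context $z_i$ gets its own $\tilde{W}$. In the paper's model, the projection layer would then map each of these matrices to a sparse DAG.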

Paper Structure (29 sections, 7 theorems, 66 equations, 7 figures, 3 algorithms)

Key Result

Proposition 1

Let $\tilde{W}\in\mathbb{R}^{p\times p}$, let $\hat{W}$ be the projection of $\tilde{W}$ onto the $\log\,\det$ level set, and let $W^\star$ be the projection of $\hat{W}$ onto the $\ell_1$ ball. Then $W^\star$ lies in the intersection of the $\log\,\det$ level set and the $\ell_1$ ball.
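Proposition 1 can be checked numerically on a small example. The following intuition is our own reading and not the paper's proof, and it assumes Euclidean projections. Projection onto an $\ell_1$ ball amounts to soft-thresholding by some $\theta \ge 0$, which only shrinks or zeroes entries. The support of $W^\star$ is therefore contained in that of $\hat{W}$. An acyclic $\hat{W}$ thus yields an acyclic $W^\star$, and $h_s$ stays at zero.

The sketch below uses the sort-based $\ell_1$ projection of Duchi et al. (2008). A hand-picked DAG stands in for $\hat{W}$, so the paper's log-det projection is not implemented. `project_l1_ball` and `h_dagma` are hypothetical names.

```python
import numpy as np

def h_dagma(W, s=1.0):
    # Log-det acyclicity function: zero iff W is acyclic (within its domain).
    p = W.shape[0]
    _, logdet = np.linalg.slogdet(s * np.eye(p) - W * W)
    return -logdet + p * np.log(s)

def project_l1_ball(W, radius):
    """Euclidean projection of W onto {X : sum_ij |X_ij| <= radius} (sort-based)."""
    if radius <= 0:
        return np.zeros_like(W)
    v = np.abs(W).ravel()
    if v.sum() <= radius:
        return W.copy()  # already inside the ball, so the projection is W itself
    u = np.sort(v)[::-1]  # magnitudes in descending order
    css = np.cumsum(u)
    k = np.arange(1, u.size + 1)
    rho = np.nonzero(u * k > css - radius)[0][-1]  # last index meeting the condition
    theta = (css[rho] - radius) / (rho + 1.0)  # soft-threshold level
    return np.sign(W) * np.maximum(np.abs(W) - theta, 0.0)

# Hand-picked DAG standing in for W_hat (edges 0->1, 0->2, 1->2).
W_hat = np.array([[0.0, 0.8, -0.3],
                  [0.0, 0.0,  1.2],
                  [0.0, 0.0,  0.0]])
W_star = project_l1_ball(W_hat, radius=1.0)
print(W_star)
print(np.abs(W_star).sum(), h_dagma(W_star))
```

Here the threshold drops the weakest edge and shrinks the other two, so $W^\star$ stays on the level set $h_s = 0$ while meeting the $\ell_1$ budget.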

Figures (7)

  • Figure 1: Illustration of the contextual DAG. The true graph is a function of the features $z$. The left three graphs correspond to one realization of $z$, and the right three correspond to a second independent realization.
  • Figure 2: Neural network architecture of the contextual DAG. The features $z$ pass through hidden layers to produce a simple directed graph $\tilde{W}$. The projection layer makes $\tilde{W}$ acyclic and sparse, resulting in a DAG $W^\star$.
  • Figure 3: Run times in seconds for 10 epochs of a contextual DAG on an NVIDIA RTX 4090 over 10 synthetic datasets. The number of nodes is $p=20$ in the left plot, and the sample size is $n=1000$ in the right plot. The number of contextual features is $m=2$. The solid points are averages, and the error bars are one standard error.
  • Figure 4: Structure recovery performance on varying Erdős-Rényi graphs over 10 synthetic datasets. The number of nodes is $p=20$ in the top row, and the sample size is $n=1000$ in the bottom row. The solid points are averages, and the error bars are one standard error. The sorted DAG (truth) uses the ground-truth topological order.
  • Figure 5: The contextual DAG from the drug consumption dataset. The left plot shows graph sparsity as a function of the neuroticism and sensation-seeking scores. The other plots show the graphs with the scores at low levels (0.1 quantile) and high levels (0.9 quantile). The parameter $\lambda$ is set to attain five edges on average over $z$.
  • ...and 2 more figures
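The sorted DAG (truth) baseline in Figure 4 relies on the ground-truth topological order. The snippet below sketches how such an order can be read off a known weighted adjacency matrix using Kahn's algorithm. It assumes that a nonzero $W_{ij}$ denotes an edge $i \to j$. How the baseline then uses the order is not described here, and `topological_order` is a hypothetical helper.

```python
from collections import deque

import numpy as np

def topological_order(W, tol=0.0):
    """Topological order of the graph with edge i -> j whenever |W[i, j]| > tol.

    Uses Kahn's algorithm and raises ValueError if the graph contains a cycle.
    """
    A = np.abs(W) > tol
    p = A.shape[0]
    indeg = A.sum(axis=0).astype(int)  # number of parents of each node
    queue = deque(i for i in range(p) if indeg[i] == 0)
    order = []
    while queue:
        i = queue.popleft()
        order.append(i)
        for j in np.flatnonzero(A[i]):  # children of node i
            indeg[j] -= 1
            if indeg[j] == 0:
                queue.append(int(j))
    if len(order) != p:
        raise ValueError("graph contains a cycle")
    return order

# Edges 1 -> 0 and 2 -> 1, so node 2 must come first.
W_true = np.array([[0.0,  0.0, 0.0],
                   [1.5,  0.0, 0.0],
                   [0.0, -0.7, 0.0]])
print(topological_order(W_true))
```

Nodes with no remaining parents are emitted first. A cycle leaves some nodes with a positive in-degree, which the final length check detects.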

Theorems & Definitions (14)

  • Proposition 1
  • Theorem 1
  • Theorem 2
  • proof
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • ...and 4 more