Contextual Directed Acyclic Graphs
Ryan Thompson, Edwin V. Bonilla, Robert Kohn
TL;DR
This work addresses learning directed acyclic graphs (DAGs) whose structure varies with context, rather than a single fixed graph. It introduces a neural network that maps contextual features $z$ to a dense matrix $\tilde{W}$, followed by a projection layer that enforces acyclicity via the DAGMA constraint $h_s(W)=0$ and sparsity via an $\ell_1$ ball, producing a context-specific weighted adjacency matrix $W(z)$. The authors derive an analytic gradient through the projection layer and prove convergence of the log-det projection, enabling scalable, GPU-accelerated training. Empirical results on synthetic data and a drug-consumption dataset show that the contextual DAG accurately recovers context-dependent graphs and outperforms fixed-DAG baselines. An open-source Julia implementation, ContextualDAG, is provided.
Abstract
Estimating the structure of directed acyclic graphs (DAGs) from observational data remains a significant challenge in machine learning. Most research in this area concentrates on learning a single DAG for the entire population. This paper considers an alternative setting where the graph structure varies across individuals based on available "contextual" features. We tackle this contextual DAG problem via a neural network that maps the contextual features to a DAG, represented as a weighted adjacency matrix. The neural network is equipped with a novel projection layer that ensures the output matrices are sparse and satisfy a recently developed characterization of acyclicity. We devise a scalable computational framework for learning contextual DAGs and provide a convergence guarantee and an analytical gradient for backpropagating through the projection layer. Our experiments suggest that the new approach can recover the true context-specific graph where existing approaches fail.
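As a rough illustration of the acyclicity characterization referred to above, the following sketch evaluates the DAGMA log-determinant function $h_s(W) = -\log\det(sI - W \circ W) + d\log s$, which is nonnegative on its domain and equals zero exactly when $W$ is the weighted adjacency matrix of a DAG. It is not the ContextualDAG implementation (which is in Julia) and it does not implement the projection layer. The function names `log_det` and `h_s`, the choice of Python, and the example matrices are assumptions made purely for illustration.

```python
import math


def log_det(a):
    """Return (sign, log|det|) of a square matrix.

    Uses Gaussian elimination with partial pivoting on a copy of `a`.
    Illustrative helper (hypothetical name), adequate for small matrices.
    """
    n = len(a)
    m = [row[:] for row in a]
    sign, log_abs = 1.0, 0.0
    for k in range(n):
        # Choose the row with the largest entry in column k as the pivot row.
        p = max(range(k, n), key=lambda i: abs(m[i][k]))
        if m[p][k] == 0.0:
            return 0.0, float("-inf")  # singular matrix
        if p != k:
            m[k], m[p] = m[p], m[k]
            sign = -sign  # a row swap flips the determinant's sign
        pivot = m[k][k]
        if pivot < 0:
            sign = -sign
        log_abs += math.log(abs(pivot))
        # Eliminate the entries below the pivot.
        for i in range(k + 1, n):
            f = m[i][k] / pivot
            for j in range(k + 1, n):
                m[i][j] -= f * m[k][j]
    return sign, log_abs


def h_s(w, s=1.0):
    """DAGMA acyclicity function h_s(W) = -log det(sI - W∘W) + d log s.

    `w` is a d x d weighted adjacency matrix given as a list of lists.
    The positive-determinant check below is only a necessary condition
    for W to lie in the function's domain, not a sufficient one.
    """
    d = len(w)
    m = [[(s if i == j else 0.0) - w[i][j] ** 2 for j in range(d)]
         for i in range(d)]
    sign, log_abs = log_det(m)
    if sign <= 0:
        raise ValueError("sI - W∘W does not have a positive determinant")
    return -log_abs + d * math.log(s)


# Hypothetical example matrices (not taken from the paper).
acyclic = [[0.0, 0.8, -0.5],
           [0.0, 0.0, 0.3],
           [0.0, 0.0, 0.0]]   # edges 1->2, 1->3, 2->3: a DAG
cyclic = [[0.0, 0.5],
          [0.5, 0.0]]         # edges 1->2 and 2->1: a cycle

print(h_s(acyclic))
print(h_s(cyclic))
```

The acyclic example evaluates to zero, while the two-node cycle gives a strictly positive value, which is the property that the constraint $h_s(W)=0$ exploits.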
