Transformers Provably Learn Directed Acyclic Graphs via Kernel-Guided Mutual Information
Yuan Cheng, Yu Huang, Zhe Xiong, Yingbin Liang, Vincent Y. F. Tan
TL;DR
This work advances the theory and practice of learning directed acyclic graphs (DAGs) from sequence data by introducing kernel-guided mutual information (KG-MI) and a multi-head attention framework in which each head is associated with a distinct marginal transition kernel. It proves that gradient ascent on the KG-MI-based objective converges to a global optimum in polynomial time, and it characterizes the attention patterns at convergence, showing that KG-MI instantiated with the KL divergence (KL-KG-MI) recovers the true DAG adjacency matrix. The results extend prior analyses, which were limited to trees where each node has a single parent, to general DAGs, and experiments on synthetic graphs demonstrate practical efficiency and accuracy, including exact recovery under the KL divergence. Overall, the method provides a principled, provable approach to multi-parent DAG structure discovery with transformers, along with guidance on choosing the divergence to optimize convergence speed and accuracy.
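The claim that converged attention patterns recover the adjacency can be made concrete with a toy readout. The sketch below is hypothetical and is not the paper's procedure: it assumes each head's converged attention puts its largest weight on one parent position, that attention is causal (a node attends only to earlier positions), and that the first `num_roots` positions have no parents. The function name `adjacency_from_attention` and the argmax rule are illustrative choices.

```python
from typing import List

Matrix = List[List[float]]


def adjacency_from_attention(heads: List[Matrix], num_roots: int = 1) -> List[List[int]]:
    """Hypothetical readout: each attention head nominates one parent per node.

    heads[k][i][j] is head k's attention weight from child position i to an
    earlier position j (causal, j < i). Returns adj with adj[j][i] == 1 when
    the edge j -> i is predicted.
    """
    n = len(heads[0])
    adj = [[0] * n for _ in range(n)]
    # Position 0 never has an earlier position, so start at 1 at the earliest;
    # positions before num_roots are assumed to be parentless roots.
    start = max(num_roots, 1)
    for scores in heads:
        for i in range(start, n):
            # Only earlier positions are eligible parents; ties go to the lowest index.
            parent = max(range(i), key=lambda j: scores[i][j])
            adj[parent][i] = 1
    return adj


# Two heads over three positions; node 2's intended parents are nodes 0 and 1.
heads = [
    [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.9, 0.1, 0.0]],
    [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.2, 0.8, 0.0]],
]
print(adjacency_from_attention(heads, num_roots=2))  # [[0, 0, 1], [0, 0, 1], [0, 0, 0]]
```

Taking one argmax per head reflects the framework's premise that different heads learn different parent relationships; with soft or tied attention scores, a threshold rule might be a better fit than a hard argmax.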
Abstract
Uncovering hidden graph structures underlying real-world data is a critical challenge with broad applications across scientific domains. Recently, transformer-based models leveraging the attention mechanism have demonstrated strong empirical success in capturing complex dependencies within graphs. However, the theoretical understanding of their training dynamics has been limited to tree-like graphs, where each node depends on a single parent. Extending provable guarantees to more general directed acyclic graphs (DAGs), in which a node may have multiple parents, remains challenging, primarily because it is difficult to design training objectives that enable different attention heads to separately learn different parent relationships. In this work, we address this problem by introducing a novel information-theoretic metric: the kernel-guided mutual information (KG-MI), based on the $f$-divergence. Our objective combines KG-MI with a multi-head attention framework, where each head is associated with a distinct marginal transition kernel to model diverse parent-child dependencies effectively. We prove that, given sequences generated by a $K$-parent DAG, training a single-layer, multi-head transformer via gradient ascent converges to the global optimum in polynomial time. Furthermore, we characterize the attention score patterns at convergence. In addition, when specializing the $f$-divergence to the KL divergence, the learned attention scores accurately reflect the ground-truth adjacency matrix, thereby provably recovering the underlying graph structure. Experimental results validate our theoretical findings.
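KG-MI is built on the $f$-divergence, but its kernel-guided definition is not given in this section, so the following Python sketch covers only the standard $f$-divergence mutual information it starts from, $I_f(X;Y) = \sum_{x,y} p(x)p(y)\, f\!\left(\frac{p(x,y)}{p(x)p(y)}\right)$, for a discrete joint distribution. The helper names (`f_mutual_information`, `kl_generator`, `chi2_generator`) are hypothetical. Choosing $f(t) = t\log t$ recovers the Shannon mutual information, which is the KL special case under which the abstract states that the adjacency matrix is recovered.

```python
import math
from typing import Callable, List


def f_mutual_information(joint: List[List[float]], f: Callable[[float], float]) -> float:
    """Standard f-MI: I_f(X;Y) = D_f(P_XY || P_X x P_Y) for a discrete joint pmf.

    This is NOT the paper's kernel-guided KG-MI, only the f-divergence base.
    """
    px = [sum(row) for row in joint]         # marginal of X (row sums)
    py = [sum(col) for col in zip(*joint)]   # marginal of Y (column sums)
    total = 0.0
    for x, row in enumerate(joint):
        for y, pxy in enumerate(row):
            q = px[x] * py[y]
            # Cells with q == 0 also have pxy == 0 and contribute nothing.
            if q > 0:
                total += q * f(pxy / q)
    return total


def kl_generator(t: float) -> float:
    # f(t) = t log t (with 0 log 0 = 0) turns D_f into the KL divergence.
    return t * math.log(t) if t > 0 else 0.0


def chi2_generator(t: float) -> float:
    # f(t) = (t - 1)^2 gives the chi-squared divergence.
    return (t - 1.0) ** 2


# Perfectly correlated binary pair versus an independent pair.
correlated = [[0.5, 0.0], [0.0, 0.5]]
independent = [[0.25, 0.25], [0.25, 0.25]]
print(f_mutual_information(correlated, kl_generator))    # about 0.693 nats (ln 2)
print(f_mutual_information(correlated, chi2_generator))  # 1.0
print(f_mutual_information(independent, kl_generator))   # 0.0
```

Swapping the generator $f$ changes the divergence while keeping the same estimator structure, which is the degree of freedom the paper exposes when it compares divergence choices.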
