Transformers Provably Learn Directed Acyclic Graphs via Kernel-Guided Mutual Information

Yuan Cheng, Yu Huang, Zhe Xiong, Yingbin Liang, Vincent Y. F. Tan

TL;DR

This work advances the theory and practice of learning Directed Acyclic Graphs (DAGs) from sequence data by introducing kernel-guided mutual information (KG-MI) and a multi-head attention framework in which each head uses a distinct marginal transition kernel. It proves that gradient ascent on the KG-MI-based objective converges to a global optimum in polynomial time and characterizes the attention patterns at convergence, showing that with the KL divergence (KL-KG-MI) the learned attention recovers the true DAG adjacency matrix. The results extend prior tree-focused analyses to general DAGs and demonstrate practical efficiency and accuracy on synthetic graphs, including exact recovery under the KL divergence. Overall, the method provides a principled, provable approach to multi-parent DAG structure discovery with transformers, together with guidance on choosing the divergence to optimize convergence speed and accuracy.

Abstract

Uncovering hidden graph structures underlying real-world data is a critical challenge with broad applications across scientific domains. Recently, transformer-based models leveraging the attention mechanism have demonstrated strong empirical success in capturing complex dependencies within graphs. However, the theoretical understanding of their training dynamics has been limited to tree-like graphs, where each node depends on a single parent. Extending provable guarantees to more general directed acyclic graphs (DAGs), which involve multiple parents per node, remains challenging, primarily because it is difficult to design training objectives that enable different attention heads to separately learn different parent relationships. In this work, we address this problem by introducing a novel information-theoretic metric: the kernel-guided mutual information (KG-MI), based on the $f$-divergence. Our objective combines KG-MI with a multi-head attention framework in which each head is associated with a distinct marginal transition kernel, so that diverse parent-child dependencies are modeled effectively. We prove that, given sequences generated by a $K$-parent DAG, training a single-layer, multi-head transformer via gradient ascent converges to the global optimum in polynomial time. Furthermore, we characterize the attention score patterns at convergence. In addition, when the $f$-divergence is specialized to the KL divergence, the learned attention scores accurately reflect the ground-truth adjacency matrix, thereby provably recovering the underlying graph structure. Experimental results validate our theoretical findings.
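KG-MI is built on the $f$-divergence. As background only, the sketch below computes the standard $f$-mutual information $I_f(X;Y)=D_f(P_{XY}\,\|\,P_XP_Y)=\sum_{x,y}P_X(x)P_Y(y)\,f\big(P_{XY}(x,y)/(P_X(x)P_Y(y))\big)$ of a discrete joint distribution, using $f(u)=u\log u$ (KL) and $f(u)=(u-1)^2$ ($\chi^2$). It is not the kernel-guided variant of Definition 3.2, whose per-head marginal transition kernels are not reproduced here; the function names are hypothetical.

```python
import math
from typing import Callable, Sequence


def f_mutual_information(joint: Sequence[Sequence[float]],
                         f: Callable[[float], float]) -> float:
    """Standard f-MI: D_f(P_XY || P_X P_Y) for a discrete joint pmf.

    joint[x][y] = P(X=x, Y=y); entries are assumed to sum to 1.
    f must be convex with f(1) = 0.
    """
    px = [sum(row) for row in joint]            # marginal of X
    py = [sum(col) for col in zip(*joint)]      # marginal of Y
    total = 0.0
    for x, row in enumerate(joint):
        for y, pxy in enumerate(row):
            prod = px[x] * py[y]
            if prod > 0:                        # P_XY(x, y) = 0 whenever prod = 0
                total += prod * f(pxy / prod)
    return total


def f_kl(u: float) -> float:
    # u log u, with the convention 0 log 0 = 0, yields the KL divergence.
    return u * math.log(u) if u > 0 else 0.0


def f_chi2(u: float) -> float:
    # (u - 1)^2 yields the chi-squared divergence.
    return (u - 1.0) ** 2
```

For a perfectly correlated pair of fair bits, `[[0.5, 0.0], [0.0, 0.5]]`, KL-MI equals $\log 2$ and $\chi^2$-MI equals $1$; for any product distribution, both are $0$. The two choices of $f$ score the same dependence on different scales, which is one reason the choice of divergence can matter for optimization.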

Paper Structure

This paper contains 31 sections, 15 theorems, 86 equations, 8 figures, 1 table, 1 algorithm.

Key Result

Theorem 4.1

Suppose Assumptions Assp: MC-non sym and assp: est of F hold, and every estimate of the $f$-KG-MI chosen in Algorithm 1 is bounded above by $I_{\max}$, i.e., $\max_{i,j \in [T]}\max_{\ell\in [K]} \hat{I}^\ell_f(S_i; S_j)\leq I_{\max}$. Let $\epsilon<\frac{1}{4}\Delta$, and for any attention error … Then, for any iteration $t\geq \tau_*$, the output of Algorithm 1, $\theta(t)$, satisfies $L^*-L(\theta(t))$ …
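As a toy illustration of why an information gap $\Delta$ governs convergence (not the paper's Algorithm 1 or objective), assume each candidate parent position $j$ carries a fixed, hypothetical information value $I_j$ and a single query's attention is $p=\mathrm{softmax}(a)$. Gradient ascent on $L(a)=\sum_j p_j I_j$, whose gradient is $\partial L/\partial a_k = p_k(I_k - L)$, concentrates attention on the largest $I_j$, and the optimality gap $L^*-L$ with $L^*=\max_j I_j$ shrinks at a rate set by the gap between the best and second-best values.

```python
import math


def softmax(a):
    m = max(a)                                  # subtract max for numerical stability
    e = [math.exp(v - m) for v in a]
    s = sum(e)
    return [v / s for v in e]


def ascend(info, lr=1.0, steps=2000):
    """Toy gradient ascent on L(a) = sum_j softmax(a)_j * info[j].

    info holds hypothetical per-position information values.
    Returns the final attention weights and the optimality gap L* - L.
    """
    a = [0.0] * len(info)                       # uniform attention at initialization
    for _ in range(steps):
        p = softmax(a)
        L = sum(pj * ij for pj, ij in zip(p, info))
        # Softmax gradient: dL/da_k = p_k * (info[k] - L).
        a = [ak + lr * pk * (ik - L) for ak, pk, ik in zip(a, p, info)]
    p = softmax(a)
    return p, max(info) - sum(pj * ij for pj, ij in zip(p, info))


p, gap = ascend([0.10, 0.55, 0.30, 0.05])      # gap between top two values is 0.25
```

With these made-up values, the weights concentrate on index 1 and the gap approaches zero; shrinking the top-two gap slows this down.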

Figures (8)

  • Figure 1: The process of using a transformer to learn the graph structure proceeds as follows. Random sequences generated by a DAG (top left) are fed into the transformer (bottom left), which produces multi-head attention scores (bottom right). These scores are then used to estimate the adjacency matrix of the DAG (top right). The DAG illustrated in the top left consists of five nodes with the edge set $E = \{(1,3), (1,4), (2,3), (2,5), (3,4), (4,5)\}$. It features two root nodes, labeled 1 and 2, and three non-root nodes labeled 3, 4, and 5, each with an in-degree of 2. The adjacency matrix at the top right highlights parent-child relationships, with positions marked in dark red and purple. The attention scores from the transformer's multi-head layers are visualized in the bottom right, where darker colors indicate higher attention values. Note that using a standard symmetric multi-head architecture results in head collapse, leading to erroneous learning. Conversely, our KG-MI objective effectively mitigates this issue, enabling successful graph structure learning.
  • Figure 2: Heatmaps of the true adjacency matrices of the DAG $\mathcal{G}$ (first column) are shown alongside the attention patterns learned by two heads using different mutual information measures. The top row displays results obtained with KL mutual information (second and third columns), while the bottom row presents results with $\chi^2$-mutual information (second and third columns). In this experiment, the data is sampled from a graph with 10 nodes, the first two of which are roots. To visualize the learned multi-head attention scores, different colors (red and blue) represent the attention patterns of each head. When trained with the naive objective (Eq: stra obj), both KL and $\chi^2$-mutual information lead to head collapse: the two heads produce identical heatmaps and recover only a single parent node in $\mathcal{G}$.
  • Figure 3: The structure of the meta-graph with 10 nodes. It includes two root nodes, 1 and 2, and eight non-root nodes, 3 to 10, each with an in-degree of 2. The edge set is $E=\{(1,3),(1,4),(1,10),(2,3),(2,5),(2,6),(2,7),(3,4),(3,8),(4,5),(4,7),(5,6),(6,9),(7,8),(7,10),(8,9)\}$.
  • Figure 4: The attention scores of the transformer trained with KG-MI. We compare the heatmap of the adjacency matrix of graph $\mathcal{G}$ (left) with the attention scores learned using the true KL-KG-MI (middle) and the estimated $\chi^2$-KG-MI (right). To present the learned multi-head attention scores, for the non-root nodes (nodes $3$ to $10$), we add the attention scores of heads $1$ and $2$ into one heatmap and negate the highest score of head $2$ to distinguish it from head $1$. The first two nodes of the graph are root nodes; consequently, for node $1$, heads $1$ and $2$ attend to nothing, and for node $2$, heads $1$ and $2$ attend to node $1$. Both learned heatmaps converge to the true adjacency matrix of $\mathcal{G}$.
  • Figure 5: Convergence rates with respect to the length $T$ of the random sequences (subfigure a) and the information gap $\Delta$ (subfigure b). A comprehensive comparison of different $f$-KG-MIs on the sampled random sequences (subfigure c).
  • ...and 3 more figures
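The edge sets in the captions of Figures 1 and 3 can be checked mechanically. The sketch below (helper names are hypothetical) builds each adjacency matrix with rows indexing the child and columns the parent, matching the row-per-query layout of the attention heatmaps, then confirms the stated roots, the in-degree of 2 for every non-root node, and acyclicity via Kahn's topological sort.

```python
from collections import deque


def adjacency(n, edges):
    """A[child][parent] = 1 for each edge (parent, child); nodes are 1-indexed."""
    A = [[0] * n for _ in range(n)]
    for parent, child in edges:
        A[child - 1][parent - 1] = 1
    return A


def in_degrees(A):
    return [sum(row) for row in A]              # number of parents per node


def is_acyclic(n, edges):
    """Kahn's algorithm: the graph is a DAG iff every node is eventually popped."""
    indeg = [0] * (n + 1)
    children = [[] for _ in range(n + 1)]
    for p, c in edges:
        indeg[c] += 1
        children[p].append(c)
    queue = deque(v for v in range(1, n + 1) if indeg[v] == 0)
    popped = 0
    while queue:
        v = queue.popleft()
        popped += 1
        for c in children[v]:
            indeg[c] -= 1
            if indeg[c] == 0:
                queue.append(c)
    return popped == n


FIG1_EDGES = [(1, 3), (1, 4), (2, 3), (2, 5), (3, 4), (4, 5)]
FIG3_EDGES = [(1, 3), (1, 4), (1, 10), (2, 3), (2, 5), (2, 6), (2, 7), (3, 4),
              (3, 8), (4, 5), (4, 7), (5, 6), (6, 9), (7, 8), (7, 10), (8, 9)]

for n, edges in [(5, FIG1_EDGES), (10, FIG3_EDGES)]:
    degs = in_degrees(adjacency(n, edges))
    roots = [i + 1 for i, d in enumerate(degs) if d == 0]
    print(n, roots, degs, is_acyclic(n, edges))
```

In both graphs every edge $(p, c)$ has $p < c$, so the node labels already form a topological order; this is consistent with a causal, left-to-right attention mask in which each position can see all of its parents.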

Theorems & Definitions (34)

  • Definition 2.1: Causal Self-Attention Head
  • Definition 3.1
  • Definition 3.2: Kernel-Guided Mutual Information
  • Theorem 4.1: Convergence of Objective Function
  • Theorem 4.2: Attention Concentration
  • Theorem 4.3
  • Lemma B.1: Existence of a Stationary Distribution and Identical Marginals
  • proof
  • Lemma B.2: Property of Stationary Distribution
  • proof
  • ...and 24 more
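Definition 2.1 is not reproduced on this page. As a generic sketch, assuming the usual form $\mathrm{softmax}$ of query-key scores restricted to earlier positions, a single causal self-attention head can be written as below. The `strict` flag is an assumption: the Figure 4 caption says node 1 attends to nothing while node 2 attends to node 1, which matches a strict mask ($j < i$); the non-strict mask ($j \le i$) is the more common convention.

```python
import math


def matmul(A, B):
    # Plain-list matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def causal_attention(X, WQ, WK, WV, strict=False):
    """One causal self-attention head on a length-T sequence X (T x d).

    Position i attends to j <= i (or j < i when strict=True).
    Returns (weights, outputs); weights[i][j] is the attention of i on j.
    A position with no visible keys gets all-zero weights and output.
    """
    T = len(X)
    Q, K, V = matmul(X, WQ), matmul(X, WK), matmul(X, WV)
    d_v = len(V[0])
    weights, outputs = [], []
    for i in range(T):
        keys = range(i) if strict else range(i + 1)   # visible past positions
        w = [0.0] * T                                  # future positions stay at zero
        if len(keys) > 0:
            scores = [sum(q * k for q, k in zip(Q[i], K[j])) for j in keys]
            m = max(scores)                            # subtract max for stability
            e = [math.exp(s - m) for s in scores]
            z = sum(e)
            for j, ej in zip(keys, e):
                w[j] = ej / z
        weights.append(w)
        outputs.append([sum(w[j] * V[j][c] for j in range(T)) for c in range(d_v)])
    return weights, outputs
```

A multi-head layer in this style would run several such heads with separate weight matrices; the paper's additional ingredient, tying each head to its own marginal transition kernel through KG-MI, lives in the training objective rather than in this forward pass.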