Graph Structure Inference with BAM: Introducing the Bilinear Attention Mechanism
Philipp Froehlich, Heinz Koeppl
TL;DR
Graph structure inference from observational data is framed as a supervised learning problem in which covariance information is processed on the manifold of symmetric positive definite (SPD) matrices. The Bilinear Attention Mechanism (BAM) combines channel embeddings, observational self-attention, SPD-based bilinear attention, and a Log-Eig mapping. Training data are generated from structural equation models (SEMs) with randomly generated multivariate Chebyshev polynomials to cover diverse dependencies. BAM delivers robust undirected graph recovery and competitive CPDAG estimation. It first identifies skeleton and moralized edges, then orients edges with a dedicated CPDAG network informed by Meek rules. The approach generalizes well to nonlinear dependencies and is computationally efficient relative to unsupervised methods, since inference requires only a single forward pass. It also opens new avenues for SPD-manifold optimization in graph learning.
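The TL;DR mentions a Log-Eig mapping applied to covariance (SPD) matrices. Assuming it denotes the standard eigenvalue-based matrix logarithm, a minimal NumPy sketch follows. For an SPD matrix S = U diag(λ) Uᵀ, the map returns U diag(log λ) Uᵀ. This sends S from the curved SPD manifold to the flat space of symmetric matrices, where ordinary Euclidean operations apply. The function names `log_eig` and `exp_eig` and the eigenvalue floor `eps` are hypothetical illustrations, not BAM's actual layer. Where and how the mapping is placed inside the network is not specified here.

```python
import numpy as np

def log_eig(spd: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Matrix logarithm of an SPD matrix via eigendecomposition.

    `eps` is a hypothetical floor on the eigenvalues that guards against
    tiny non-positive eigenvalues caused by floating-point error.
    """
    # Symmetrize to remove round-off asymmetry before calling eigh.
    sym = 0.5 * (spd + spd.T)
    eigvals, eigvecs = np.linalg.eigh(sym)
    # Clamp eigenvalues so the logarithm is defined.
    log_vals = np.log(np.maximum(eigvals, eps))
    # Reassemble U diag(log λ) U^T (broadcasting scales each eigenvector column).
    return (eigvecs * log_vals) @ eigvecs.T

def exp_eig(sym: np.ndarray) -> np.ndarray:
    """Inverse map: matrix exponential of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(0.5 * (sym + sym.T))
    return (eigvecs * np.exp(eigvals)) @ eigvecs.T

# Usage: covariance of simulated observations (500 samples, 4 variables).
rng = np.random.default_rng(0)
x = rng.standard_normal((500, 4))
cov = np.cov(x, rowvar=False)  # 4x4 SPD matrix
log_cov = log_eig(cov)         # symmetric matrix in the tangent (log) space
```

The round trip `exp_eig(log_eig(cov))` recovers `cov`. This is one way to check that the sketch implements a proper logarithm on the SPD manifold.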
Abstract
In statistics and machine learning, detecting dependencies in datasets is a central challenge. We propose a novel neural network model for supervised graph structure learning, i.e., the process of learning a mapping between observational data and their underlying dependence structure. The model is trained with variably shaped and coupled simulated input data and requires only a single forward pass through the trained network for inference. By leveraging structural equation models and employing randomly generated multivariate Chebyshev polynomials for the simulation of training data, our method demonstrates robust generalizability across both linear and various types of non-linear dependencies. We introduce a novel bilinear attention mechanism (BAM) for explicit processing of dependency information, which operates on the level of covariance matrices of transformed data and respects the geometry of the manifold of symmetric positive definite matrices. Empirical evaluation demonstrates the robustness of our method in detecting a wide range of dependencies, excelling in undirected graph estimation and proving competitive in completed partially directed acyclic graph estimation through a novel two-step approach.
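The abstract states that training data are simulated from structural equation models with randomly generated multivariate Chebyshev polynomials. A minimal sketch of that idea is shown below. Each sampled pair `(x, adj)` would serve as one supervised example, mapping observations to their graph. Several choices are assumptions rather than the paper's generator, and all helper names (`sample_sem`, `random_chebyshev_fn`, `random_dag`, `to_unit_interval`) and parameters are hypothetical:

- DAGs are sampled as strictly upper-triangular adjacency matrices.
- Each parent is min-max rescaled to [-1, 1], the natural domain of the Chebyshev polynomials T_k.
- Each non-root node is a sum of random products of univariate T_k terms over its parents, standardized and perturbed by additive Gaussian noise.

```python
import numpy as np
from numpy.polynomial import chebyshev

def random_dag(d, edge_prob, rng):
    """Random DAG as a strictly upper-triangular 0/1 matrix (edge i -> j only if i < j)."""
    return np.triu((rng.random((d, d)) < edge_prob).astype(int), k=1)

def to_unit_interval(v):
    """Min-max rescale a sample vector to [-1, 1]."""
    lo, hi = v.min(), v.max()
    return np.zeros_like(v) if hi == lo else 2.0 * (v - lo) / (hi - lo) - 1.0

def random_chebyshev_fn(n_parents, max_degree, n_terms, rng):
    """Random multivariate polynomial f(x) = sum_t c_t * prod_i T_{k_ti}(x_i)."""
    degrees = rng.integers(0, max_degree + 1, size=(n_terms, n_parents))
    coeffs = rng.standard_normal(n_terms)
    basis = np.eye(max_degree + 1)  # row k is the coefficient vector selecting T_k

    def f(x_parents):  # x_parents: (n_samples, n_parents), values in [-1, 1]
        out = np.zeros(x_parents.shape[0])
        for c, ks in zip(coeffs, degrees):
            term = np.ones(x_parents.shape[0])
            for i, k in enumerate(ks):
                term *= chebyshev.chebval(x_parents[:, i], basis[k])
            out += c * term
        return out
    return f

def sample_sem(d=5, n_samples=1000, edge_prob=0.4, max_degree=3, n_terms=4,
               noise_scale=0.5, seed=0):
    """Sample observations x (n_samples, d) and the ground-truth adjacency adj (d, d)."""
    rng = np.random.default_rng(seed)
    adj = random_dag(d, edge_prob, rng)
    x = np.zeros((n_samples, d))
    for j in range(d):  # column order is already a topological order
        parents = np.flatnonzero(adj[:, j])
        if parents.size == 0:
            x[:, j] = rng.standard_normal(n_samples)  # root node: pure noise
            continue
        scaled = np.column_stack([to_unit_interval(x[:, p]) for p in parents])
        signal = random_chebyshev_fn(parents.size, max_degree, n_terms, rng)(scaled)
        std = signal.std()
        signal = signal / std if std > 0 else signal  # keep variances bounded
        x[:, j] = signal + noise_scale * rng.standard_normal(n_samples)
    return x, adj

x, adj = sample_sem()
```

Degree-1 terms yield linear dependencies, while higher degrees and products of several parents yield non-linear and interacting ones. This reflects why random Chebyshev polynomials can cover both regimes in one simulator.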
