Graph Classification via Reference Distribution Learning: Theory and Practice
Zixiao Wang, Jicong Fan
TL;DR
This work tackles graph classification without global pooling, instead treating each graph's node embeddings as a discrete distribution. It introduces Graph Reference Distribution Learning (GRDL), which learns $K$ discriminative reference distributions and classifies a graph by the maximum mean discrepancy (MMD) between its node embeddings and each reference, in an end-to-end trainable setup with a Gaussian kernel. The authors derive a generalization bound for GRDL and discuss how model choices (GNN depth, reference size, and kernel) affect performance. They also show that GRDL can generalize better than pooling-based GNNs, while training and running inference at least an order of magnitude faster on large-scale datasets. Experiments on extensive benchmarks show that GRDL achieves state-of-the-art or competitive accuracy at a much lower time cost, and ablations validate its key design choices. Overall, GRDL offers a principled, efficient, and scalable approach to graph classification, with theoretical guarantees and practical value for large graph datasets.
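The classification rule described above can be illustrated with a minimal sketch. It assumes a Gaussian kernel $k(x, y) = \exp(-\gamma \lVert x - y \rVert^2)$ and the biased empirical estimate of squared MMD, and it scores each reference by negative squared MMD. The function names, the fixed toy references, and the value of $\gamma$ are hypothetical. In GRDL the GNN and the references are learned end to end, and this sketch does not attempt that training; it only shows how a set of node embeddings is compared with $K$ reference sets.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def gaussian_kernel(x: Point, y: Point, gamma: float) -> float:
    """Gaussian (RBF) kernel exp(-gamma * ||x - y||^2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-gamma * sq_dist)


def mean_kernel(xs: List[Point], ys: List[Point], gamma: float) -> float:
    """Average kernel value over all pairs (x, y) with x in xs and y in ys."""
    total = sum(gaussian_kernel(x, y, gamma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd_squared(xs: List[Point], ys: List[Point], gamma: float) -> float:
    """Biased (V-statistic) estimate of squared MMD between two point sets."""
    return (mean_kernel(xs, xs, gamma) + mean_kernel(ys, ys, gamma)
            - 2.0 * mean_kernel(xs, ys, gamma))


def classify(nodes: List[Point], references: List[List[Point]], gamma: float) -> int:
    """Return the index of the reference with the smallest squared MMD.

    Assumption: similarity = -MMD^2. No pooling is applied, so the whole
    set of node embeddings is compared with each reference set.
    """
    scores = [-mmd_squared(nodes, ref, gamma) for ref in references]
    return max(range(len(scores)), key=scores.__getitem__)


# Hypothetical toy references in a 2-D embedding space (K = 2).
refs = [
    [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1]],  # reference 0: points near the origin
    [[1.0, 1.0], [0.9, 1.0], [1.0, 0.9]],  # reference 1: points near (1, 1)
]
graph_nodes = [[0.95, 0.95], [1.0, 1.05]]  # node embeddings of one graph

print(classify(graph_nodes, refs, gamma=1.0))  # → 1
```

Because graphs can have different numbers of nodes, the MMD estimate averages over all node pairs. This is what lets a variable-size node set be compared with a fixed-size reference without first collapsing it into a single vector.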
Abstract
Graph classification is challenging because it is difficult to quantify the similarity between graphs or to represent graphs as vectors, although a few methods based on graph kernels or graph neural networks (GNNs) exist. Graph kernels often suffer from high computational costs and rely on manual feature engineering, while GNNs commonly use global pooling operations, which risk losing structural or semantic information. This work introduces Graph Reference Distribution Learning (GRDL), an efficient and accurate graph classification method. GRDL treats the latent node embeddings that GNN layers produce for each graph as a discrete distribution. It then classifies the graph directly, without global pooling, based on the maximum mean discrepancy (MMD) between this distribution and adaptively learned reference distributions. To fully understand this new model (existing theories do not apply) and to guide its configuration for practical use (e.g., network architecture, the number and size of references, and regularization), we derive generalization error bounds for GRDL and verify them numerically. More importantly, both our theoretical and numerical results show that GRDL generalizes better than GNNs with global pooling operations. Experiments on moderate-scale and large-scale graph datasets show that GRDL outperforms state-of-the-art methods while being remarkably efficient: it is at least 10 times faster than leading competitors in both training and inference.
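The concern that global pooling can discard information is illustrated by the toy example below, which is not taken from the paper. Two hypothetical graphs have scalar node embeddings with the same mean, so global mean pooling maps them to the same vector. A Gaussian-kernel MMD between the two node-embedding distributions still separates them. Mean pooling stands in for global pooling in general, and the embeddings are one-dimensional only for brevity.

```python
import math
from typing import List


def mean_pool(nodes: List[float]) -> float:
    """Global mean pooling of scalar node embeddings."""
    return sum(nodes) / len(nodes)


def mmd2(xs: List[float], ys: List[float], gamma: float = 1.0) -> float:
    """Biased estimate of squared MMD with a Gaussian kernel on scalars."""

    def mean_k(us: List[float], vs: List[float]) -> float:
        # Average of exp(-gamma * (u - v)^2) over all pairs.
        return sum(math.exp(-gamma * (u - v) ** 2) for u in us for v in vs) / (
            len(us) * len(vs))

    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)


graph_a = [-1.0, 1.0]  # spread-out node embeddings
graph_b = [0.0, 0.0]   # concentrated node embeddings

print(mean_pool(graph_a), mean_pool(graph_b))  # → 0.0 0.0 (pooling cannot tell them apart)
print(mmd2(graph_a, graph_b) > 0)              # → True (the distributions still differ)
```

This only demonstrates that a distributional comparison can retain information that mean pooling discards. It says nothing about GRDL's learned references or its generalization bounds.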
