Sparse Decomposition of Graph Neural Networks

Yaochen Hu, Mai Zeng, Ge Zhang, Pavel Rumiantsev, Liheng Ma, Yingxue Zhang, Mark Coates

TL;DR

SDGNN introduces a sparse decomposition to approximate target GNN embeddings with low online inference cost. By learning a feature transform $\phi(\cdot;\mathbf{W})$ and node-specific sparse weights $\boldsymbol{\theta}_z$, it represents each node as $\hat{g}(z,\mathbf{X}|\mathcal{G}) = {\boldsymbol{\theta}}_z^\top \phi(\mathbf{X};\mathbf{W})$, enabling per-node complexity $O(\bar{d}L)$. The optimization alternates between a Lasso-based phase for $\boldsymbol{\theta}_z$ and a gradient-based phase for $\mathbf{W}$, with scalable strategies like mini-batching and candidate-set narrowing to handle large graphs. Empirical results on seven node-classification datasets and two spatio-temporal forecasting tasks show SDGNN closely matches or surpasses target GNN performance while offering substantially reduced inference times, making online prediction with dynamic node features feasible. This work provides a practical framework for deploying GNNs in real-time settings by balancing expressive power and inference efficiency.
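The decomposition $\hat{g}(z,\mathbf{X}|\mathcal{G}) = \boldsymbol{\theta}_z^\top \phi(\mathbf{X};\mathbf{W})$ only needs the rows of $\mathbf{X}$ where $\boldsymbol{\theta}_z$ is nonzero. The sketch below shows online inference for one node. It assumes a linear transform $\phi(\mathbf{X};\mathbf{W}) = \mathbf{X}\mathbf{W}$ (the paper's $\phi$ may be richer) and stores each $\boldsymbol{\theta}_z$ as an (indices, values) pair. The names `phi`, `sdgnn_embed`, `support` and `weights` are hypothetical and are not taken from the authors' code.

```python
import numpy as np

def phi(X, W):
    # Assumption: phi(X; W) is a linear map applied row-wise, X @ W.
    return X @ W

def sdgnn_embed(z, X, W, support, weights):
    """Approximate the target embedding of node z as theta_z^T phi(X; W).

    support[z] holds the indices of the nonzero entries of theta_z and
    weights[z] their values, so only len(support[z]) rows of X are
    fetched and transformed; all other nodes are never read.
    """
    idx = support[z]
    return weights[z] @ phi(X[idx], W)

# Usage: node 0 aggregates from nodes 0, 2 and 5 only.
rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
W = rng.normal(size=(4, 3))
support = {0: np.array([0, 2, 5])}
weights = {0: np.array([0.5, 0.3, 0.2])}
h0 = sdgnn_embed(0, X, W, support, weights)
```

Only the support of $\boldsymbol{\theta}_z$ is fetched, so the per-node cost scales with its number of nonzeros rather than with the size of the multi-hop neighbourhood. The $O(\bar{d}L)$ figure above is consistent with keeping that support on the order of $\bar{d}L$ nodes. When node features change online, the same call simply reads the fresh rows.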

Abstract

Graph Neural Networks (GNNs) exhibit superior performance in graph representation learning, but their inference cost can be high due to an aggregation operation that can require memory fetches for a very large number of nodes. This inference cost is the major obstacle to deploying GNN models with online prediction that reflects potentially dynamic node features. To address this, we propose an approach that reduces the number of nodes included during aggregation. We achieve this through a sparse decomposition, learning to approximate node representations using a weighted sum of linearly transformed features of a carefully selected subset of nodes within the extended neighbourhood. The approach achieves linear complexity with respect to the average node degree and the number of layers in the graph neural network. We introduce an algorithm to compute the optimal parameters for the sparse decomposition, ensuring an accurate approximation of the original GNN model, and present effective strategies to reduce the training time and improve the learning process. We demonstrate via extensive experiments that our method outperforms other baselines designed for inference speedup, achieving significant accuracy gains with comparable inference times for both node classification and spatio-temporal forecasting tasks.
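The alternating optimization (a Lasso phase for each $\boldsymbol{\theta}_z$, then a gradient phase for $\mathbf{W}$) can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' algorithm. It takes $\phi$ to be linear ($\mathbf{X}\mathbf{W}$). It solves the Lasso with plain ISTA (proximal gradient) in place of whatever solver the paper uses, and the objective scaling is chosen here for simplicity. `candidates` is a hypothetical stand-in for candidate-set narrowing, for example each node's k-hop neighbourhood. All function names are hypothetical.

```python
import numpy as np

def soft_threshold(v, t):
    # Proximal operator of t * ||v||_1 (elementwise shrinkage toward zero).
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def lasso_ista(A, b, lam, iters=300):
    # Minimise 0.5 * ||A @ theta - b||^2 + lam * ||theta||_1 with ISTA,
    # using step 1/L where L = ||A||_2^2 is the gradient's Lipschitz constant.
    step = 1.0 / max(np.linalg.norm(A, 2) ** 2, 1e-12)
    theta = np.zeros(A.shape[1])
    for _ in range(iters):
        grad = A.T @ (A @ theta - b)
        theta = soft_threshold(theta - step * grad, step * lam)
    return theta

def fit_sdgnn(X, H, candidates, lam=0.01, rounds=10, w_steps=50, seed=0):
    """Alternate a per-node Lasso phase (Theta) and a gradient phase (W).

    X: (n, f) node features.  H: (n, d) target GNN embeddings.
    candidates[z]: node indices allowed in the support of theta_z
    (hypothetical stand-in for candidate-set narrowing).
    Theta is kept dense (n, n) for readability only.
    """
    rng = np.random.default_rng(seed)
    n, f = X.shape
    W = rng.normal(scale=0.1, size=(f, H.shape[1]))
    Theta = np.zeros((n, n))
    for _ in range(rounds):
        Phi = X @ W                              # assumed linear phi(X; W)
        for z in range(n):                       # Lasso phase, W fixed
            C = list(candidates[z])
            Theta[z] = 0.0
            Theta[z, C] = lasso_ista(Phi[C].T, H[z], lam)
        TX = Theta @ X                           # gradient phase, Theta fixed
        step = 1.0 / max(np.linalg.norm(TX, 2) ** 2, 1e-12)
        for _ in range(w_steps):
            # Gradient step on 0.5 * ||TX @ W - H||_F^2 with respect to W.
            W -= step * (TX.T @ (TX @ W - H))
    return Theta, W
```

Setting both step sizes to the inverse Lipschitz constants ($1/\|A\|_2^2$ for the Lasso, $1/\|\boldsymbol{\Theta}\mathbf{X}\|_2^2$ for $\mathbf{W}$) makes each phase non-increasing in its own objective. This keeps the sketch stable without a tuned learning rate. A scalable version would store $\boldsymbol{\Theta}$ sparsely and process nodes in mini-batches.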

Paper Structure

This paper contains 46 sections, 12 equations, 9 figures, 13 tables, 1 algorithm.

Figures (9)

  • Figure 1: The pipeline overview for SDGNN framework (bottom pipeline). To compute GNN embedding efficiently, we use a transformation function to adapt node features and introduce sparse vectors associated with each node to gather information from critical neighbours. The parameters in the transformation function and the sparse vectors are determined by optimization to approximate the target GNN embeddings.
  • Figure 2: Accuracy vs. mean inference wall-clock time over 10,000 randomly sampled nodes on the Products test set. GL$\rightarrow$GLNN, NOS$\rightarrow$NOSMOG, Rev$\rightarrow$RevGNN. w4, w8: student size enlarged 4 and 8 times, respectively. L1, L2, and L3 denote GNN layers. N20: a neighbour sampling size of 20.
  • Figure 3: (a) MAPE vs. mean receptive field size on the PeMS08 dataset. L1 and L2 denote one and two GCN layers, respectively. GLNN and NOSMOG are learned with GRU-GCN-L2. Targets of SDGNN are indicated in brackets. Error bars show standard deviation. (b) Scatter plot of the MAE reduction at each node (compared to GRU) for GRU-GCN (y-axis) versus SDGNN (x-axis). The dashed line is $y=x$. The improvements are highly correlated, with a slight bias in favour of the more computationally expensive GRU-GCN.
  • Figure 4: Convergence curves under various candidate sets for the Pubmed dataset with the SAGE model.
  • Figure 5: Empirical CDFs of the number of receptive nodes and of the row-normalized ${\mathbf{\Theta}}$ for SDGNN, SGC, and PPRGo on ArXiv (SAGE, DRGAT) and Ogbn-Products (SAGE, RevGNN-112).
  • ...and 4 more figures
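Figures 3(a) and 5 summarise the sparsity of ${\mathbf{\Theta}}$ through receptive-field sizes and their CDFs. The sketch below computes such statistics from a dense ${\mathbf{\Theta}}$ whose row $z$ is $\boldsymbol{\theta}_z$. The row normalization used here (absolute weights rescaled to sum to 1) is an assumption, since the captions do not specify it. The function names are hypothetical.

```python
import numpy as np

def receptive_field_sizes(Theta):
    # Receptive nodes of node z = support of theta_z (nonzeros in row z).
    return np.count_nonzero(Theta, axis=1)

def empirical_cdf(values):
    # Sorted values paired with rank / len, the usual step-plot form of an ECDF.
    v = np.sort(np.asarray(values, dtype=float))
    return v, np.arange(1, len(v) + 1) / len(v)

def row_normalize_abs(Theta):
    # Assumed normalization: each row's absolute weights rescaled to sum to 1;
    # all-zero rows stay zero.
    a = np.abs(np.asarray(Theta, dtype=float))
    s = a.sum(axis=1, keepdims=True)
    return np.divide(a, s, out=np.zeros_like(a), where=s > 0)
```

The mean of `receptive_field_sizes(Theta)` corresponds to the "mean receptive field size" plotted on the x-axis of Figure 3(a), under the same reading of ${\mathbf{\Theta}}$.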