Mirage: Model-Agnostic Graph Distillation for Graph Classification

Mridul Gupta, Sahil Manchanda, Hariprasad Kodamana, Sayan Ranu

TL;DR

Graph neural networks are powerful but demand extensive data and computation. Mirage introduces a model-agnostic, unsupervised graph distillation framework that compresses training data by mining frequently co-occurring computation trees, the structures that a message-passing GNN (MP-GNN) operates on, so that a single synthetic dataset can train multiple architectures. Across six real-world datasets and multiple GNNs, Mirage achieves higher generalization accuracy, about 4–5× data compression, and roughly 150× faster distillation compared to gradient-based baselines, while remaining CPU-bound. The approach exploits the skewed distribution of computation trees and uses frequent pattern mining to preserve informative patterns, offering a practical, robust, and scalable way to distill graph data. Its limitations are generalization to unseen tasks and domains where the tree distribution is not skewed.

Abstract

GNNs, like other deep learning models, are data and computation hungry. There is a pressing need to scale training of GNNs on large datasets to enable their use in low-resource environments. Graph distillation is an effort in that direction, with the aim of constructing a smaller synthetic training set from the original training data without significantly compromising model performance. While initial efforts are promising, this work is motivated by two key observations: (1) Existing graph distillation algorithms themselves rely on training with the full dataset, which undermines the very premise of graph distillation. (2) The distillation process is specific to the target GNN architecture and hyper-parameters and is thus not robust to changes in the modeling pipeline. We circumvent these limitations by designing a distillation algorithm called Mirage for graph classification. Mirage is built on the insight that a message-passing GNN decomposes the input graph into a multiset of computation trees. Furthermore, the frequency distribution of computation trees is often skewed, enabling us to condense this data into a concise distilled summary. By compressing the computation data itself, as opposed to emulating gradient flows on the original training set (a prevalent approach to date), Mirage becomes an unsupervised and architecture-agnostic distillation algorithm. Extensive benchmarking on real-world datasets underscores Mirage's superiority, showing improved generalization accuracy, data compression, and distillation efficiency compared to state-of-the-art baselines.
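The decomposition described in the abstract can be illustrated with a minimal Python sketch. It is not the authors' implementation: the adjacency-dict graph encoding, the string-based canonical label, the toy dataset, and the function names `tree_label`, `graph_trees`, and `normalized_frequencies` are all assumptions made for illustration. The sketch also assumes that a tree's frequency counts the graphs containing it at least once, which keeps the normalized value in $[0,1]$.

```python
from collections import Counter

def tree_label(adj, feats, node, hops):
    """Canonical string for the `hops`-deep computation tree rooted at `node`.

    The tree unrolls message passing: the children of a node are all of its
    graph neighbors, repeated at every level. Sorting the child labels makes
    the string independent of neighbor order.
    """
    if hops == 0:
        return str(feats[node])
    children = sorted(tree_label(adj, feats, nb, hops - 1) for nb in adj[node])
    return f"{feats[node]}({','.join(children)})"

def graph_trees(adj, feats, hops):
    """Multiset of computation trees of one graph, one tree per node."""
    return Counter(tree_label(adj, feats, v, hops) for v in adj)

def normalized_frequencies(dataset, hops):
    """Fraction of graphs containing each tree at least once (assumed reading)."""
    containing = Counter()
    for adj, feats in dataset:
        containing.update(set(graph_trees(adj, feats, hops)))
    return {tree: count / len(dataset) for tree, count in containing.items()}

# Hypothetical toy dataset: two 3-node paths and a star; features are atom-like labels.
path_a = ({0: [1], 1: [0, 2], 2: [1]}, {0: "C", 1: "C", 2: "O"})
path_b = ({0: [1], 1: [0, 2], 2: [1]}, {0: "C", 1: "C", 2: "C"})
star = ({0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}, {0: "N", 1: "C", 2: "C", 3: "C"})
dataset = [path_a, path_b, star]

freqs = normalized_frequencies(dataset, hops=1)
for tree, f in sorted(freqs.items(), key=lambda kv: -kv[1]):
    print(f"{tree:10s} {f:.2f}")
```

Even on this tiny example one tree (`C(C)`) appears in more graphs than any other, which is the kind of skew the distillation relies on. The string label is only safe when feature values contain no parentheses or commas; a real implementation would need a more robust canonical form.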

Paper Structure

This paper contains 28 sections, 7 equations, 12 figures, 9 tables, 2 algorithms.

Figures (12)

  • Figure 1: Pipeline of Mirage.
  • Figure 2: Frequency distribution of computation trees across datasets. The "frequency" of a computation tree denotes the number of occurrences of that specific tree across all graphs in a dataset. The normalized frequency of a tree is computed by dividing its frequency by the total number of graphs in the dataset and thus falls in the range $[0,1]$. The $x$-axis of the plot depicts the normalized frequency counts observed in a dataset, while the $y$-axis represents the percentage of computation trees corresponding to each frequency count. Both the $x$ and $y$ axes are in log scale. The distribution is highly skewed: trees with low frequency counts dominate, while a small subset of trees exhibits much higher frequencies. For example, in ogbg-molhiv, the most frequent tree alone has a normalized frequency of $0.32$.
  • Figure 3: In (a) we show the construction of the computation tree for $v_0\in\mathcal{G}_1$. In (b), we present $\mathcal{G}_2$, which has an isomorphic $2$-hop computation tree for $u_2$ despite its neighborhood being non-isomorphic to that of $v_0$. We assume the node feature vectors to be a one-hot encoding of the node colors.
  • Figure 4: (a) Distillation times for the different methods. (b) Distillation time as a function of number of hops on ogbg-molbbbp dataset.
  • Figure 5: (a) A model is trained on the full dataset and its weights are saved after each epoch. For each saved checkpoint, the weights are loaded and kept fixed, and both the dataset condensed using Mirage and the full dataset are passed through the model. The plot shows the difference between the two losses, which approaches 0. (b) Training loss vs. epochs on ogbg-molhiv (GCN). Results on more datasets can be found in Appendix \ref{sec:train_eff}.
  • ...and 7 more figures
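
The observation in the Figure 3 caption, that two nodes can share a computation tree even when their neighborhoods are not isomorphic, can be checked on a small example. The graphs below are not the ones in the figure: a triangle and a hexagon with a single node color are assumed here because they give a simple instance, and `computation_tree` and `cycle` are hypothetical helper names.

```python
def computation_tree(adj, color, node, hops):
    """Nested-tuple canonical form of the `hops`-hop computation tree at `node`."""
    if hops == 0:
        return (color[node], ())
    kids = sorted(computation_tree(adj, color, nb, hops - 1) for nb in adj[node])
    return (color[node], tuple(kids))

def cycle(n):
    """Adjacency dict of an n-node cycle graph."""
    return {i: [(i - 1) % n, (i + 1) % n] for i in range(n)}

triangle, hexagon = cycle(3), cycle(6)
tri_color = {v: "grey" for v in triangle}
hex_color = {v: "grey" for v in hexagon}

t = computation_tree(triangle, tri_color, 0, hops=2)
h = computation_tree(hexagon, hex_color, 0, hops=2)

# The 2-hop neighborhood of a triangle node is the triangle itself, while for a
# hexagon node it is a 5-node path, yet the unrolled trees are identical.
print(t == h)
```

Every node in both graphs has two neighbors of the same color, so message passing sees the same unrolled tree in each case. This is why counting computation trees, rather than subgraphs, matches what an MP-GNN can distinguish.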

Theorems & Definitions (6)

  • Definition 1: Graph
  • Definition 2: Graph Isomorphism
  • Definition 3: Computation Tree
  • proof
  • proof
  • Definition 4: Canonical label