GraphFM: A Scalable Framework for Multi-Graph Pretraining

Divyansha Lachi, Mehdi Azabou, Vinam Arora, Eva Dyer

TL;DR

GraphFM introduces a scalable multi-graph pretraining framework that learns a generalist model across diverse graph domains by compressing heterogeneous graphs into a fixed latent space with a Perceiver encoder. It employs latent tokens as a shared vocabulary, a multi-task node decoder for reading out node-level information, and a DistributedSSSampler to balance GPU utilization, enabling training on 152 datasets comprising over 7.4 million nodes and 189 million edges. The approach yields strong cross-domain transfer, matches or surpasses specialist models on both homophilic and heterophilic graphs, and reveals scaling laws showing benefits from larger data and models. This work demonstrates that a single, foundation-style graph model can efficiently learn across domains, reducing the need for dataset-specific architectures and hyperparameter tuning, with practical implications for rapid deployment in new graph-centric tasks.

Abstract

Graph neural networks are typically trained on individual datasets, often requiring highly specialized models and extensive hyperparameter tuning. This dataset-specific approach arises because each graph dataset often has unique node features and diverse connectivity structures, making it difficult to build a generalist model. To address these challenges, we introduce a scalable multi-graph multi-task pretraining approach specifically tailored for node classification tasks across diverse graph datasets from different domains. Our method, Graph Foundation Model (GraphFM), leverages a Perceiver-based encoder that employs learned latent tokens to compress domain-specific features into a common latent space. This approach enhances the model's ability to generalize across different graphs and allows for scaling across diverse data. We demonstrate the efficacy of our approach by training a model on 152 different graph datasets comprising over 7.4 million nodes and 189 million edges, establishing the first set of scaling laws for multi-graph pretraining on datasets spanning many domains (e.g., molecules, citation and product graphs). Our results show that pretraining on a diverse array of real and synthetic graphs improves the model's adaptability and stability, while performing competitively with state-of-the-art specialist models. This work illustrates that multi-graph pretraining can significantly reduce the burden imposed by the current graph training paradigm, unlocking new capabilities for the field of graph neural networks by creating a single generalist model that performs competitively across a wide range of datasets and tasks.
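
The compression step described above can be illustrated with a minimal sketch, which is not the authors' implementation. It assumes node features have already been projected to a common dimension, uses a single attention head with no learned query/key/value projections, and draws the latent tokens at random rather than learning them. The names `cross_attend`, `make_vectors`, `DIM`, and `NUM_LATENTS` are hypothetical and exist only for this example.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cross_attend(latents, tokens):
    """Each latent token (query) attends over all node tokens (keys = values).

    Simplification (assumption): no learned projections, single head.
    The output has one row per latent, regardless of how many tokens come in.
    """
    dim = len(latents[0])
    scale = 1.0 / math.sqrt(dim)
    out = []
    for q in latents:
        weights = softmax([dot(q, t) * scale for t in tokens])
        out.append(
            [sum(w * t[d] for w, t in zip(weights, tokens)) for d in range(dim)]
        )
    return out


def make_vectors(count, dim, rng):
    """Random stand-ins for learned latents or projected node features."""
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(count)]


rng = random.Random(0)
DIM, NUM_LATENTS = 8, 4  # hypothetical sizes, chosen only for illustration

# The same latent tokens are reused for every graph; a real model would learn them.
latents = make_vectors(NUM_LATENTS, DIM, rng)

small_graph = make_vectors(5, DIM, rng)   # 5 node tokens
large_graph = make_vectors(50, DIM, rng)  # 50 node tokens

compressed_small = cross_attend(latents, small_graph)
compressed_large = cross_attend(latents, large_graph)

# Both graphs end up with the same number of latent tokens.
print(len(compressed_small), len(compressed_large))  # prints "4 4"
```

Because the latents act as queries, the cost of this step grows linearly with the number of node tokens while the output size stays fixed, which is what lets graphs of very different sizes share the same downstream self-attention layers.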

Paper Structure

This paper contains 32 sections, 3 equations, 7 figures, 6 tables, and 1 algorithm.

Figures (7)

  • Figure 1: Overview of GraphFM architecture and multi-graph training approach: The input node-level tokens are passed into a cross-attention layer, and then through multiple self-attention layers. To decode node-level properties, we create a spatial sequence from the features of a query node and a subset of its neighbors; this sequence is passed through self-attention layers and then through a node decoder that applies self-attention across the node and its neighbors.
  • Figure 2: Characteristics of graph datasets used to train GraphFM: From left to right, we compute the histograms of the homophily ratio, average degree, number of nodes and number of edges of all 152 graphs used during training.
  • Figure 3: Multi-GPU utilization: GPU memory utilization during distributed training when using the default batch sampler with 8 GPUs (left) vs. our DistributedSSSampler for N=4 (middle) and N=64 (right) GPUs. The total batch size is $N\times b$.
  • Figure 4: Scaling analysis: Average accuracy across OOD datasets for different model sizes (389K, 18M, 75M) and different amounts of tokens (200K, 2M, 7.3M) seen during the pre-training phase.
  • Figure 5: Hyperparameter sensitivity and learning curves: A: The performance of GCN and GraphFM for 100 different random hyperparameters on Coauthor-CS and Chameleon. The star denotes the model with the optimal hyperparameter, and the color indicates a normalized L2 distance between the hyperparameters of each model and this optimal solution. B: Learning curves for 100 random GCN models and GraphFM finetuning for Coauthor-CS and Chameleon. Refer to the appendix finetuning figure (fig:app_finetuning) for additional datasets.
  • ...and 2 more figures
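
Two of the dataset statistics histogrammed in Figure 2 can be computed as in the sketch below. It assumes the edge-level definition of the homophily ratio (the fraction of edges whose endpoints share a label) and an undirected edge list with each edge listed once. The caption does not say which homophily variant is used, so this is an illustrative choice, and the function names and toy graph are hypothetical.

```python
def edge_homophily(edges, labels):
    """Fraction of edges whose two endpoints carry the same label.

    Assumption: edge-level homophily; other variants (node-level,
    class-insensitive) exist and may give different values.
    """
    if not edges:
        return 0.0
    same = sum(1 for u, v in edges if labels[u] == labels[v])
    return same / len(edges)


def average_degree(edges, num_nodes):
    """Mean degree of an undirected graph; each edge adds one to two nodes."""
    return 2 * len(edges) / num_nodes


# Toy graph: 4 nodes, 5 undirected edges, 2 classes.
edges = [(0, 1), (1, 2), (2, 3), (3, 0), (0, 2)]
labels = [0, 0, 1, 1]

print(edge_homophily(edges, labels))       # 0.4 (2 of 5 edges join same-label nodes)
print(average_degree(edges, len(labels)))  # 2.5
```

A ratio near 1 indicates a homophilic graph and a ratio near 0 a heterophilic one.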