Differentiable Cluster Graph Neural Network

Yanfei Dong, Mohammed Haroon Dupty, Lambert Deng, Zhuanghua Liu, Yong Liang Goh, Wee Sun Lee

TL;DR

DC-GNN tackles the dual challenges of long-range information propagation and heterophily by injecting a differentiable clustering inductive bias into GNN message passing. It achieves this by augmenting the original graph with a bipartite set of cluster-nodes (global and local) and optimizing an OT-based clustering objective $\mathcal{L}_{\rm cluster}$ through a differentiable block-coordinate descent (DC-MsgPassing) that alternates between Sinkhorn-based assignment updates and closed-form embedding updates. The approach yields two key benefits: local cluster-nodes preserve graph structure and improve aggregation in heterophilous neighborhoods, while global cluster-nodes enable distant information transfer, collectively reducing oversquashing and improving performance on both heterophilous and homophilous datasets. Empirical results on 14 datasets demonstrate state-of-the-art accuracy across heterophilous graphs and strong results on homophilous graphs, with clear ablations showing the value of the clustering terms and regularizers for robust learning.

Abstract

Graph Neural Networks often struggle with long-range information propagation and in the presence of heterophilous neighborhoods. We address both challenges with a unified framework that incorporates a clustering inductive bias into the message passing mechanism, using additional cluster-nodes. Central to our approach is the formulation of an optimal transport based implicit clustering objective function. However, the algorithm for solving the implicit objective function needs to be differentiable to enable end-to-end learning of the GNN. To facilitate this, we adopt an entropy regularized objective function and propose an iterative optimization process, alternating between solving for the cluster assignments and updating the node/cluster-node embeddings. Notably, our derived closed-form optimization steps are themselves simple yet elegant message passing steps operating seamlessly on a bipartite graph of nodes and cluster-nodes. Our clustering-based approach can effectively capture both local and global information, demonstrated by extensive experiments on both heterophilous and homophilous datasets.
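
The alternation described above can be illustrated with a minimal NumPy sketch. It is not the paper's DC-MsgPassing: it assumes a squared Euclidean cost, uniform marginals, and an update of the cluster-node embeddings only (a weighted mean, which is the closed-form minimizer for that cost). The function names `sinkhorn`, `sq_dists`, `objective`, and `alternate` are hypothetical and chosen for illustration.

```python
import numpy as np

def sinkhorn(cost, lam, n_iters=200):
    """Entropy-regularized OT plan; uniform marginals are an assumption."""
    n, k = cost.shape
    a = np.full(n, 1.0 / n)          # mass per node
    b = np.full(k, 1.0 / k)          # mass per cluster-node
    K = np.exp(-cost / lam)          # Gibbs kernel
    v = np.ones(k)
    for _ in range(n_iters):
        u = a / (K @ v)              # rescale rows
        v = b / (K.T @ u)            # rescale columns
    return u[:, None] * K * v[None, :]

def sq_dists(X, C):
    """Pairwise squared Euclidean distances (assumed cost)."""
    return ((X[:, None, :] - C[None, :, :]) ** 2).sum(-1)

def objective(X, C, P, lam):
    """Transport cost minus lam times the entropy of P."""
    return float((P * sq_dists(X, C)).sum() + lam * (P * np.log(P + 1e-300)).sum())

def alternate(X, C, lam=1.0, n_rounds=10):
    """Alternate assignment updates and closed-form centroid updates."""
    history = []
    for _ in range(n_rounds):
        P = sinkhorn(sq_dists(X, C), lam)            # assignment step
        C = (P.T @ X) / P.sum(axis=0)[:, None]       # weighted-mean step
        history.append(objective(X, C, P, lam))
    return P, C, history

# Toy data: two well-separated blobs, one initial centroid taken from each.
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal([-2.0, 0.0], 0.5, size=(20, 2)),
    rng.normal([2.0, 0.0], 0.5, size=(20, 2)),
])
P, C, history = alternate(X, X[[0, -1]].copy())
```

Every operation in the loop is a matrix product or an elementwise operation, so the same steps written in an autodiff framework would be differentiable end to end, which is the property the abstract relies on.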

Paper Structure

This paper contains 50 sections, 2 theorems, 35 equations, 8 figures, 15 tables, and 1 algorithm.

Key Result

Theorem 3.3

Assuming the Sinkhorn-Knopp algorithm is run to convergence in each iteration, for any $\lambda > 0$, the value of $\mathcal{L}^{\lambda}_{\rm cluster}$ produced by the DC-MsgPassing algorithm is guaranteed to converge.
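
A standard way to see why a statement of this kind can hold is the monotone-descent argument for block-coordinate methods. The sketch below is not taken from the paper and its proof may differ. The symbols $P^{(t)}$ (assignments) and $Z^{(t)}$ (embeddings) are notation assumed here, and boundedness from below is an assumption of this sketch rather than something stated in the theorem.

```latex
% Assumed notation: P^{(t)} = assignments, Z^{(t)} = embeddings after iteration t.
% Embedding step (left inequality): Z^{(t+1)} minimizes over Z with P^{(t+1)} fixed.
% Assignment step (right inequality): P^{(t+1)} minimizes over P with Z^{(t)} fixed.
\mathcal{L}^{\lambda}_{\rm cluster}\bigl(P^{(t+1)}, Z^{(t+1)}\bigr)
  \;\le\; \mathcal{L}^{\lambda}_{\rm cluster}\bigl(P^{(t+1)}, Z^{(t)}\bigr)
  \;\le\; \mathcal{L}^{\lambda}_{\rm cluster}\bigl(P^{(t)}, Z^{(t)}\bigr)
```

If each step is an exact minimization, the objective values form a non-increasing sequence; if the objective is also bounded below, the monotone convergence theorem gives a limit. This concerns the objective values only, not the iterates themselves.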

Figures (8)

  • Figure 1: On the left is an instance where distant nodes are similar to each other. On the right is the heterophilous ego-neighborhood of node A where cluster patterns appear. Square boxes indicate conceptual cluster centroids.
  • Figure 2: Overview of DC-GNN. DC-MsgPassing is an iterative algorithm that minimizes $\mathcal{L}_{\rm cluster}^\lambda$ in each step of message passing, where $\mathcal{L}_{\rm cluster}^\lambda$ is an optimal transport based clustering objective function.
  • Figure 3: (a)-(b) Total effective resistance heatmap. (c) Accuracy for Tree-NeighborsMatch dataset.
  • Figure 4: Based on the original graph on the left, we construct a bipartite graph on the right by adding local and global cluster-nodes. For each node in the original graph, a set of local cluster-nodes, represented by the blue boxes at the top, is connected to its ego-neighborhood. For example, the ego-neighborhood of node a includes itself and its one-hop neighbor, node b, so the local cluster-nodes for node a are connected to a and b. Meanwhile, a set of global cluster-nodes, represented by the blue boxes at the bottom, is added and connected to all nodes in the original graph.
  • Figure 5: Total effective resistance heatmap of Erdős-Rényi random graphs at different sparsity levels. All settings use 10 nodes.
  • ...and 3 more figures
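
The bipartite construction described in the Figure 4 caption can be sketched as follows. This is an illustrative reading of the caption, not the authors' code: the function `build_bipartite_edges`, the tuple identifiers for cluster-nodes, and the three-node path graph are all hypothetical.

```python
from typing import Dict, Hashable, List, Set, Tuple

Edge = Tuple[Hashable, tuple]

def build_bipartite_edges(
    neighbors: Dict[Hashable, Set[Hashable]],
    n_local: int = 1,
    n_global: int = 1,
) -> List[Edge]:
    """Return (node, cluster_node) edges of the augmented bipartite graph."""
    edges: List[Edge] = []
    # Local cluster-nodes: one set per node, wired to that node's
    # ego-neighborhood (the node itself plus its one-hop neighbors).
    for center in sorted(neighbors):
        ego = {center} | set(neighbors[center])
        for j in range(n_local):
            cluster_id = ("local", center, j)
            edges.extend((node, cluster_id) for node in sorted(ego))
    # Global cluster-nodes: wired to every node in the original graph.
    for j in range(n_global):
        cluster_id = ("global", j)
        edges.extend((node, cluster_id) for node in sorted(neighbors))
    return edges

# Hypothetical path graph a - b - c, with two local and two global cluster-nodes.
graph = {"a": {"b"}, "b": {"a", "c"}, "c": {"b"}}
edges = build_bipartite_edges(graph, n_local=2, n_global=2)
local_a = {node for node, cid in edges if cid == ("local", "a", 0)}
global_0 = {node for node, cid in edges if cid == ("global", 0)}
```

Here `local_a` contains only a and b, matching the caption's example, while `global_0` contains every node. Any two nodes are therefore two hops apart through a global cluster-node, which is the route for distant information transfer.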

Theorems & Definitions (5)

  • Remark 3.1
  • Remark 3.2
  • Theorem 3.3: Convergence of DC-MsgPassing
  • Lemma B.1
  • proof