Probabilistically Rewired Message-Passing Neural Networks

Chendi Qian, Andrei Manolache, Kareem Ahmed, Zhe Zeng, Guy Van den Broeck, Mathias Niepert, Christopher Morris

TL;DR

This work devises probabilistically rewired MPNNs (PR-MPNNs), which learn to add relevant edges while omitting less beneficial ones, and identifies precise conditions under which they outperform purely randomized approaches.

Abstract

Message-passing graph neural networks (MPNNs) emerged as powerful tools for processing graph-structured input. However, they operate on a fixed input graph structure, ignoring potential noise and missing information. Furthermore, their local aggregation mechanism can lead to problems such as over-squashing and limited expressive power in capturing relevant graph structures. Existing solutions to these challenges have primarily relied on heuristic methods, often disregarding the underlying data distribution. Hence, devising principled approaches for learning to infer graph structures relevant to the given prediction task remains an open challenge. In this work, leveraging recent progress in exact and differentiable $k$-subset sampling, we devise probabilistically rewired MPNNs (PR-MPNNs), which learn to add relevant edges while omitting less beneficial ones. For the first time, our theoretical analysis explores how PR-MPNNs enhance expressive power, and we identify precise conditions under which they outperform purely randomized approaches. Empirically, we demonstrate that our approach effectively mitigates issues like over-squashing and under-reaching. In addition, on established real-world datasets, our method exhibits competitive or superior predictive performance compared to traditional MPNN models and recent graph transformer architectures.
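The abstract relies on exact $k$-subset sampling. As a rough illustration (not the authors' implementation), one standard way to draw exactly $k$ of $n$ candidate edges from independent Bernoulli priors $\theta_i$, conditioned on the total being exactly $k$, is a dynamic program over elementary symmetric polynomials of the odds $\theta_i / (1 - \theta_i)$. The names `suffix_esp` and `sample_k_subset` are hypothetical, and the sketch uses plain floating point, so a log-space version would be needed for large $n$.

```python
import random

def suffix_esp(weights, k):
    """E[i][j] = sum over all size-j subsets S of items i..n-1 of prod(weights[s] for s in S)."""
    n = len(weights)
    E = [[0.0] * (k + 1) for _ in range(n + 1)]
    E[n][0] = 1.0  # the empty subset of an empty suffix
    for i in range(n - 1, -1, -1):
        E[i][0] = 1.0
        for j in range(1, k + 1):
            # either skip item i, or take it and choose j-1 more from the rest
            E[i][j] = E[i + 1][j] + weights[i] * E[i + 1][j - 1]
    return E

def sample_k_subset(theta, k, rng=None):
    """Draw z in {0,1}^n from prod_i Bernoulli(theta_i) conditioned on sum(z) == k."""
    assert 0 <= k <= len(theta), "need 0 <= k <= number of candidates"
    rng = rng or random.Random()
    weights = [t / (1.0 - t) for t in theta]  # odds; assumes 0 < theta_i < 1
    E = suffix_esp(weights, k)
    z, remaining = [], k
    for i, w in enumerate(weights):
        if remaining == 0:
            z.append(0)
            continue
        # exact conditional probability that item i is selected given `remaining` open slots
        p_in = w * E[i + 1][remaining - 1] / E[i][remaining]
        take = rng.random() < p_in
        z.append(int(take))
        remaining -= int(take)
    return z
```

Every draw contains exactly $k$ ones. For example, with $\theta = (0.8, 0.5)$ and $k = 1$ the odds are 4 and 1, so the first candidate is selected with probability $4 / (4 + 1) = 0.8$.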

Paper Structure (17 sections, 11 theorems, 15 equations, 7 figures, 12 tables)

Key Result

Theorem 4.1

For sufficiently large $n$, for every $\varepsilon \in (0, 1)$ and $k > 0$, we have that for almost all pairs, in the sense of Bab+1980, of isomorphic $n$-order graphs $G$ and $H$ and all permutation-invariant, $1$-WL-equivalent functions $f \colon \mathfrak{A}_n \to \mathbb{R}^d$, $d > 0$, there ex…
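Theorem 4.1 quantifies over $1$-WL-equivalent functions, roughly, functions whose power to distinguish graphs matches the 1-WL colour-refinement test. The sketch below is the textbook 1-WL refinement, not code from the paper; the names `wl_refine` and `wl_indistinguishable` are hypothetical. It runs refinement on the disjoint union of two graphs so that both share one colour palette, then compares colour histograms.

```python
from collections import Counter

def wl_refine(adj):
    """Run 1-WL colour refinement until the partition is stable.

    adj maps each node to a list of neighbours; all nodes start with the same
    colour (no node features). Returns a dict node -> integer colour."""
    colors = {v: 0 for v in adj}
    while True:
        # new signature = own colour plus the sorted multiset of neighbour colours
        sigs = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v]))) for v in adj}
        palette = {s: i for i, s in enumerate(sorted(set(sigs.values())))}
        new_colors = {v: palette[sigs[v]] for v in adj}
        # refinement only splits classes, so an unchanged class count means stability
        if len(set(new_colors.values())) == len(set(colors.values())):
            return new_colors
        colors = new_colors

def wl_indistinguishable(g, h):
    """True if 1-WL gives g and h identical colour histograms (computed on their disjoint union)."""
    union = {("g", v): [("g", u) for u in nbrs] for v, nbrs in g.items()}
    union.update({("h", v): [("h", u) for u in nbrs] for v, nbrs in h.items()})
    colors = wl_refine(union)
    hist_g = Counter(c for (tag, _), c in colors.items() if tag == "g")
    hist_h = Counter(c for (tag, _), c in colors.items() if tag == "h")
    return hist_g == hist_h
```

A classic case: a 6-cycle and two disjoint triangles are both 2-regular on six nodes, so 1-WL assigns them identical histograms and no 1-WL-equivalent function can separate them. This is the kind of limited expressive power the abstract attributes to purely local aggregation.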

Figures (7)

  • Figure 1: Overview of the probabilistically rewired MPNN framework. PR-MPNNs use an upstream model to learn priors $\bm{\theta}$ for candidate edges, parameterizing a probability mass function conditioned on exactly-$k$ constraints. Subsequently, we sample multiple $k$-edge adjacency matrices (here: $k=1$) from this distribution, aggregate these matrices (here: subtraction), and use the resulting adjacency matrix as input to a downstream model, typically an MPNN, for the final prediction task. On the backward pass, the gradients of the loss $\ell$ with respect to the parameters $\bm{\theta}$ are approximated through the derivative of the exactly-$k$ marginals in the direction of the gradients of the point-wise loss $\ell$ with respect to the sampled adjacency matrix. We build on recent work to compute these marginals exactly and differentiably, reducing both bias and variance.
  • Figure 2: Comparison between PR-MPNN and DropGNN on the 4-Cycles dataset. PR-MPNN rewiring is almost always better than randomly dropping nodes, and is always better with $10$ priors.
  • Figure 3: Example graph from the Trees-LeafCount test dataset with radius $4$ (left). PR-MPNN rewires the graph, allowing the downstream MPNN to obtain the label information from the leaves in one message-passing step (right).
  • Figure 4: Test accuracy of our rewiring method on the Trees-NeighborsMatch dataset (Alon2020), compared to the reported accuracies from Mue+2023.
  • Figure 5: Example graphs used in the theoretical analysis.
  • ...and 2 more figures
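The Figure 1 caption notes that the backward pass relies on the exactly-$k$ marginals $p(z_i = 1 \mid \sum_j z_j = k)$. A minimal sketch of computing them exactly follows, assuming independent Bernoulli priors $\theta_i \in (0, 1)$ over $n$ candidate edges; it combines prefix and suffix elementary-symmetric-polynomial tables and is checked against brute-force enumeration. The function names are hypothetical and this is not the authors' implementation. In practice the same recursion would be written in log space inside an autodiff framework so that derivatives of the marginals with respect to $\bm{\theta}$ are available.

```python
from itertools import combinations
from math import prod

def esp_table(weights, k):
    """T[i][j] = sum over size-j subsets S of weights[:i] of prod(weights[s] for s in S)."""
    T = [[0.0] * (k + 1) for _ in range(len(weights) + 1)]
    T[0][0] = 1.0
    for i, w in enumerate(weights, start=1):
        T[i][0] = 1.0
        for j in range(1, k + 1):
            T[i][j] = T[i - 1][j] + w * T[i - 1][j - 1]
    return T

def exactly_k_marginals(theta, k):
    """p(z_i = 1 | sum(z) == k) for independent Bernoulli(theta_i) priors, 0 < theta_i < 1."""
    n = len(theta)
    w = [t / (1.0 - t) for t in theta]  # odds of each candidate edge
    prefix = esp_table(w, k)            # prefix[i]: subsets of items 0..i-1
    suffix = esp_table(w[::-1], k)      # suffix[m]: subsets of the last m items
    Z = prefix[n][k]                    # normaliser over all size-k subsets
    marg = []
    for i in range(n):
        # item i is chosen together with a items before it and k-1-a items after it
        total = sum(prefix[i][a] * suffix[n - 1 - i][k - 1 - a] for a in range(k))
        marg.append(w[i] * total / Z)
    return marg

def brute_force_marginals(theta, k):
    """Reference answer by enumerating every size-k subset (exponential; tests only)."""
    w = [t / (1.0 - t) for t in theta]
    weight = {S: prod(w[i] for i in S) for S in combinations(range(len(w)), k)}
    Z = sum(weight.values())
    return [sum(v for S, v in weight.items() if i in S) / Z for i in range(len(w))]
```

The marginals always sum to $k$, since every admissible subset contains exactly $k$ selected edges; the dynamic program costs $O(nk^2)$ here instead of enumerating all $\binom{n}{k}$ subsets.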

Theorems & Definitions (17)

  • Theorem 4.1
  • Proposition 4.2
  • Theorem 4.3
  • Proposition 4.4
  • Theorem C.1: Restatement of a theorem from the main paper
  • Theorem C.2: Cyb+1992; Les+1993
  • Lemma C.3
  • Proof sketch
  • Lemma C.4
  • Proof sketch
  • ...and 7 more