
GraphMETRO: Mitigating Complex Graph Distribution Shifts via Mixture of Aligned Experts

Shirley Wu, Kaidi Cao, Bruno Ribeiro, James Zou, Jure Leskovec

TL;DR

GraphMETRO, a Graph Neural Network architecture that models natural diversity and captures complex distributional shifts, achieves state-of-the-art results on four datasets from the GOOD benchmark, which comprises complex and natural real-world distribution shifts.

Abstract

Graph data are inherently complex and heterogeneous, leading to a high natural diversity of distributional shifts. However, it remains unclear how to build machine learning architectures that generalize to the complex distributional shifts naturally occurring in the real world. Here, we develop GraphMETRO, a Graph Neural Network architecture that models natural diversity and captures complex distributional shifts. GraphMETRO employs a Mixture-of-Experts (MoE) architecture with a gating model and multiple expert models, where each expert model targets a specific distributional shift to produce a referential representation w.r.t. a reference model, and the gating model identifies shift components. Additionally, we design a novel objective that aligns the representations from different expert models to ensure reliable optimization. GraphMETRO achieves state-of-the-art results on four datasets from the GOOD benchmark, which comprises complex and natural real-world distribution shifts, with improvements of 67% and 4.2% on the WebKB and Twitch datasets, respectively. Code and data are available at https://github.com/Wuyxin/GraphMETRO.
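The mixture-of-experts design described above can be illustrated with a toy sketch. It is not the GraphMETRO implementation from the linked repository: linear maps over a pooled graph feature vector stand in for the GNN gating model and experts, the gating scores are assumed to be softmax-normalized, the aggregation is assumed to be a gating-weighted sum, and the alignment term is assumed to be a squared Euclidean distance between an expert's output on a shifted graph and the reference model's output on the original graph. All names (`gating`, `expert_embed`, `forward`, `alignment_loss`) are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: pooled graph-feature dim, embedding dim, number of shift components.
D_IN, D_OUT, K = 8, 4, 3

# Linear maps standing in for GNN encoders (an assumption, not GraphMETRO's models).
# Index 0 is the reference model xi_0; indices 1..K are the experts xi_1..xi_K.
expert_weights = [rng.normal(size=(D_IN, D_OUT)) for _ in range(K + 1)]
gate_weights = rng.normal(size=(D_IN, K))  # stand-in for the gating model mu


def softmax(z):
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()


def gating(x):
    """Scores over the K shift components (softmax normalization is assumed)."""
    return softmax(x @ gate_weights)


def expert_embed(i, x):
    """Representation from model i (0 = reference, 1..K = experts)."""
    return x @ expert_weights[i]


def forward(x):
    """Gating-weighted sum of the K expert representations."""
    w = gating(x)  # shape (K,)
    outs = np.stack([expert_embed(i, x) for i in range(1, K + 1)])  # shape (K, D_OUT)
    return w @ outs  # shape (D_OUT,)


def alignment_loss(i, x, x_shifted):
    """Squared distance: expert i on a shifted input vs. the reference on the clean input."""
    diff = expert_embed(i, x_shifted) - expert_embed(0, x)
    return float(diff @ diff)


# Usage: a toy "feature noise" shift applied to one pooled graph vector.
x = rng.normal(size=D_IN)
x_noisy = x + 0.1 * rng.normal(size=D_IN)
z = forward(x)
print(z.shape)  # (4,)
print(alignment_loss(3, x, x_noisy))
```

In a training loop built on this sketch, each expert $\xi_i$ would be pushed to minimize its alignment term on graphs transformed by its own shift component, so that all experts produce representations in the reference model's space before the gating model mixes them.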

Paper Structure (25 sections, 3 theorems, 21 equations, 5 figures, 7 tables)

Key Result

Theorem 1

For any graph $\mathcal{G}$ and shift component $\tau_i$, the encoder $h$ satisfies:

Figures (5)

  • Figure 1: An example from WebKB [webkbgood]. It illustrates (1) the distribution shift from source to target (the thick arrow in the upper right) and (2) instance-wise heterogeneity in the target distribution (the thin arrows pointing to $u_1$ and $u_2$).
  • Figure 2: Overview of GraphMETRO on graph classification tasks. (a) High-level concept: as a simple example, the distribution shift from a target graph $\mathcal{G}\in\mathcal{D}_t$ to a source distribution $\mathcal{D}_s$ is decomposed along three shift dimensions: graph size ($\xi_1$), node degree ($\xi_2$), and feature noise ($\xi_3$). The shift components can be customized and expanded for downstream tasks. (b) Architecture: given an input graph, the gating model $\mu$ decomposes the instance-specific distribution shift into the contributions of the shift components. Each expert model $\xi_i$ ($i > 0$) then generates referentially invariant representations (defined in the method section) w.r.t. its assigned shift component. $\xi_0$ is a reference model used to align the representation spaces of the expert models. The final representation, aggregated from the experts' outputs, is referentially invariant to any distribution shift and is then fed to the classifier.
  • Figure 3: Accuracy on synthetic distribution shifts. The first row shows test accuracy under single shift components, with distributions labeled in clockwise order. The second row shows test accuracy under distribution shifts with multiple shift components, where each test distribution composes two different transformations. For example, (1, 5) denotes a test distribution in which each graph is affected by the random subgraph (1) and noisy feature (5) shift components. Numerical values are given in the appendix.
  • Figure 4: (a) Invariance matrix on the Twitter dataset. Lighter colors indicate higher invariance of the representations produced by each expert. Small values on the diagonal of the invariance matrix indicate that each expert excels at generating invariant representations w.r.t. its assigned shift component. (b) Mixture of distribution shifts identified by GraphMETRO. Higher values indicate a stronger shift component in the test distribution.
  • Figure 5: Impact of transform function choices on model performance. Each number of transform functions corresponds to a specific set of transformations.
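To make the invariance matrix of Figure 4(a) concrete, here is a toy computation. It assumes, without confirmation from the paper, that entry (i, j) is the mean distance between expert i's representation of a graph transformed by shift component j and the reference model's representation of the untransformed graph, so smaller values mean more invariance. The shifts are hypothetical feature offsets on pooled graph vectors, and each toy expert is constructed to undo exactly one of them, so the diagonal is near zero by design rather than learned. The names `reference`, `transform`, `expert`, and `invariance_matrix` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(1)
D_IN, D_OUT, K, N = 6, 3, 3, 20  # hypothetical: feature dim, embedding dim, shifts, graphs

W_ref = rng.normal(size=(D_IN, D_OUT))                # toy linear reference model
offsets = [rng.normal(size=D_IN) for _ in range(K)]   # toy shift j = add offsets[j]


def reference(x):
    return x @ W_ref


def transform(j, x):
    """Apply hypothetical shift component j to a pooled graph vector."""
    return x + offsets[j]


def expert(i, x):
    """Toy expert i: undoes shift i, then applies the reference model."""
    return reference(x - offsets[i])


def invariance_matrix(xs):
    """Entry (i, j): mean distance of expert i on shift-j inputs from the reference on clean inputs."""
    m = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            dists = [np.linalg.norm(expert(i, transform(j, x)) - reference(x)) for x in xs]
            m[i, j] = np.mean(dists)
    return m


graphs = rng.normal(size=(N, D_IN))  # pooled features for N toy graphs
M = invariance_matrix(graphs)
print(np.round(M, 3))  # diagonal entries are ~0; off-diagonal entries are positive
```

A learned GraphMETRO expert would only approximate this behavior, so the diagonal of a real invariance matrix is small rather than exactly zero.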

Theorems & Definitions (7)

  • Definition 1: Referential Invariant Representation
  • Theorem 1: Shift-Invariance
  • Proof 1
  • Theorem 2: Composition of Shifts
  • Proof 2
  • Theorem 3: Generalization Bound
  • Proof 3