Empowering Graph Invariance Learning with Deep Spurious Infomax

Tianjun Yao, Yongqiang Chen, Zhenhao Chen, Kai Hu, Zhiqiang Shen, Kun Zhang

TL;DR

This work tackles graph out-of-distribution (OOD) generalization by exposing the brittleness of prior invariance methods that rely on assumptions about the strength of spurious correlations. It introduces EQuAD, a flexible Encoding-QuAntifying-Decorrelation framework that uses infomax-based self-supervision to capture spurious features and then learns features decorrelated from them, recovering invariant content for robust graph classification. The authors provide theoretical and empirical evidence that infomax can isolate spurious features under mild conditions, and they demonstrate strong improvements both in synthetic settings with fully and partially informative invariant features (FIIF/PIIF) and on real-world datasets, including the DrugOOD and OGB benchmarks. The approach offers a modular, plug-in pathway to stable OOD performance and may apply beyond graphs to other data modalities.
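Infomax-based self-supervision trains an encoder so that local (node) representations share as much information as possible with the global (graph) summary. The block below is a minimal pure-Python sketch of one common instance, the local-global Jensen-Shannon objective used by InfoGraph-style methods, assuming a mean-pooling readout and a dot-product discriminator. This summary does not say which infomax method, readout, or discriminator EQuAD uses, and the helper names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean_readout(nodes: List[Vector]) -> Vector:
    # Graph-level summary (assumed readout): the average of the node embeddings.
    dim = len(nodes[0])
    return [sum(node[d] for node in nodes) / len(nodes) for d in range(dim)]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def jsd_infomax_loss(batch: List[List[Vector]]) -> float:
    """Local-global infomax loss with a Jensen-Shannon MI estimator (to minimize).

    batch[g] holds the node embeddings of graph g. A node paired with the summary
    of its own graph is a positive pair; paired with another graph's summary it is
    a negative pair. The discriminator score is a plain dot product (assumption).
    """
    if len(batch) < 2:
        raise ValueError("need at least two graphs to form negative pairs")
    summaries = [mean_readout(nodes) for nodes in batch]
    pos_terms: List[float] = []
    neg_terms: List[float] = []
    for g, nodes in enumerate(batch):
        for node in nodes:
            for k, summary in enumerate(summaries):
                score = dot(node, summary)
                if k == g:
                    pos_terms.append(softplus(-score))  # rewards high positive scores
                else:
                    neg_terms.append(softplus(score))  # rewards low negative scores
    return sum(pos_terms) / len(pos_terms) + sum(neg_terms) / len(neg_terms)


# Graphs whose nodes agree with their own summary get a lower loss.
separated = [[[2.0, 0.0], [1.5, 0.0]], [[0.0, 2.0], [0.0, 1.5]]]
mixed = [[[2.0, 0.0], [0.0, 2.0]], [[2.0, 0.0], [0.0, 2.0]]]
print(jsd_infomax_loss(separated) < jsd_infomax_loss(mixed))  # True
```

Minimizing this loss raises the scores of node/own-graph pairs and lowers those of node/other-graph pairs; in a real pipeline the embeddings would come from a GNN and the loss would be minimized by gradient descent.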

Abstract

Recently, there has been a surge of interest in developing graph neural networks that utilize the invariance principle on graphs to generalize to out-of-distribution (OOD) data. Due to limited knowledge about OOD data, existing approaches often make assumptions about the correlation strengths between the underlying spurious features and the target labels. However, this prior is often unavailable and may change arbitrarily in real-world scenarios, which can lead to severe failures of existing graph invariance learning methods. To bridge this gap, we introduce a novel graph invariance learning paradigm that induces a robust and general inductive bias. The paradigm is built upon the observation that the infomax principle encourages learning spurious features regardless of spurious correlation strengths. We further propose the EQuAD framework, which realizes this learning paradigm and employs tailored learning objectives that provably elicit invariant features by disentangling them from the spurious features learned through infomax. Notably, EQuAD shows stable and enhanced performance, with gains of up to 31.76%, across different degrees of bias in synthetic datasets and on challenging real-world datasets. Our code is available at https://github.com/tianyao-aka/EQuAD.
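The Figure 3 caption gives the decorrelation objective the shape $\mathcal{L}=\mathcal{L}_{GT}+\lambda\mathcal{L}_{Inv}$, with the spurious logits $\widehat{\mathbf{s}}_i$ and labels $y_i$ as targets. The sketch below shows an objective of that shape, but both terms are stand-ins chosen for illustration, not the paper's definitions. Here $\mathcal{L}_{GT}$ is a cross-entropy reweighted so that samples the spurious logits misclassify count more, and $\mathcal{L}_{Inv}$ is a squared cross-covariance penalty between the invariant model's logits and the spurious logits. The function names, and the assumption that the spurious logits live in the label space, are hypothetical.

```python
import math
from typing import List, Sequence

Logits = Sequence[float]


def log_softmax(logits: Logits) -> List[float]:
    # Log-sum-exp with a max shift for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def reweighted_ce(y_logits: Sequence[Logits], s_logits: Sequence[Logits],
                  labels: Sequence[int]) -> float:
    # Hypothetical L_GT: a sample weighs more when the spurious predictor
    # assigns low probability to its true label.
    weights = [1.0 - math.exp(log_softmax(s)[y]) for s, y in zip(s_logits, labels)]
    total_w = sum(weights)
    if total_w <= 0.0:  # spurious predictor is perfect: fall back to uniform weights
        weights, total_w = [1.0] * len(labels), float(len(labels))
    ce = [-log_softmax(z)[y] for z, y in zip(y_logits, labels)]
    return sum(w * c for w, c in zip(weights, ce)) / total_w


def cross_covariance_penalty(y_logits: Sequence[Logits],
                             s_logits: Sequence[Logits]) -> float:
    # Hypothetical L_Inv: sum over classes of the squared batch covariance
    # between the invariant logit and the spurious logit for that class.
    n = len(y_logits)
    penalty = 0.0
    for k in range(len(y_logits[0])):
        a = [row[k] for row in y_logits]
        b = [row[k] for row in s_logits]
        mean_a, mean_b = sum(a) / n, sum(b) / n
        cov = sum((x - mean_a) * (v - mean_b) for x, v in zip(a, b)) / n
        penalty += cov ** 2
    return penalty


def decorrelated_loss(y_logits: Sequence[Logits], s_logits: Sequence[Logits],
                      labels: Sequence[int], lam: float = 1.0) -> float:
    # Same shape as L = L_GT + lambda * L_Inv; both terms are illustrative stand-ins.
    return (reweighted_ce(y_logits, s_logits, labels)
            + lam * cross_covariance_penalty(y_logits, s_logits))


y_hat = [[1.2, -0.3], [0.1, 0.9], [-0.5, 1.4]]
s_hat = [[2.0, -1.0], [1.5, -0.5], [-1.0, 2.0]]
print(decorrelated_loss(y_hat, s_hat, labels=[0, 1, 1], lam=0.5))
```

Setting `lam` (the $\lambda$ above) to zero reduces this to reweighted ERM; increasing it trades label fit for decorrelation from the spurious logits.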

Paper Structure

This paper contains 34 sections, 4 theorems, 26 equations, 12 figures, 7 tables, and 1 algorithm.

Key Result

Theorem 4.1

Given the same data generation process as in Fig. 1, with Shannon entropy $H(S)=H(C)=\delta_f$, and assuming the node representations encode proper information about the underlying latent factors, i.e., $\delta_r \geq I(\widehat{\bm{h}}_i; C) - I(\widehat{\bm{h}}_i; S) \geq \delta_l,\ \forall i \in G$, …
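The assumption in Theorem 4.1 bounds the gap between how much a node representation reveals about the invariant factor $C$ and about the spurious factor $S$. The sketch below checks that condition on a toy discrete example with a plug-in mutual information estimator (in nats). The discretized codes, variable names, and thresholds are illustrative assumptions; real node representations are continuous and would need a dedicated MI estimator.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def plugin_mutual_information(xs: Sequence[Hashable], ys: Sequence[Hashable]) -> float:
    # Plug-in estimate of I(X;Y) in nats from paired discrete samples.
    n = len(xs)
    joint = Counter(zip(xs, ys))
    px, py = Counter(xs), Counter(ys)
    mi = 0.0
    for (x, y), count in joint.items():
        pxy = count / n
        mi += pxy * math.log(pxy / ((px[x] / n) * (py[y] / n)))
    return mi


def satisfies_gap_condition(h: Sequence[Hashable], c: Sequence[Hashable],
                            s: Sequence[Hashable], delta_l: float, delta_r: float) -> bool:
    # Checks delta_r >= I(h; C) - I(h; S) >= delta_l for one node's codes h.
    gap = plugin_mutual_information(h, c) - plugin_mutual_information(h, s)
    return delta_l <= gap <= delta_r


# Toy case: h copies C and ignores S; C and S both have entropy log 2.
c = [0, 0, 1, 1]
s = [0, 1, 0, 1]
h = list(c)
print(plugin_mutual_information(h, c))  # ≈ 0.693 (= log 2)
print(satisfies_gap_condition(h, c, s, delta_l=0.5, delta_r=1.0))  # True
```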

Figures (12)

  • Figure 1: Structural causal models for graph generation.
  • Figure 2: Investigation of how well the representations learned by ERM and infomax-based SSL capture spurious features. The experimental results imply that infomax-based SSL primarily learns spurious correlations. Further details on the experimental setup are provided in the appendix.
  • Figure 3: The overall framework of EQuAD. Given an input graph consisting of $G_c$ (shown in blue) and $G_s$ (shown in red), EQuAD proceeds as follows. (a) Encoding and quantifying: infomax-based SSL is first performed to learn a collection of spurious representations (3 in this case), and $g(\cdot)$ is then applied to obtain the corresponding prediction logits as targets. (b+c) Decorrelation: in (b), a GNN encoder $h(\cdot)$ is trained from scratch to generate $\widehat{\bm{h}}_c$, followed by a classifier $\rho(\cdot)$ that produces the prediction $\widehat{y}_i$; in (c), $\widehat{y}_i$ is fed into the loss function $\mathcal{L}=\mathcal{L}_{GT}+\lambda\mathcal{L}_{Inv}$ for learning invariant features, where $\widehat{\mathbf{s}}_i$ and $y_i$ serve as targets. (d) Detailed illustration of the model-specific reweighting process. Finally, $\widehat{y}_i$ and $\widehat{\mathbf{s}}_i$ are obtained using the same classifier, i.e., $\rho(\cdot)=\rho^{\prime}(\cdot)$.
  • Figure 4: Logit distributions of the SPMotif datasets for the Cycle class, where samples are divided into two subgroups: the first consists of samples with a high correlation between $S$ and $Y$, and the second contains samples with a low correlation between $S$ and $Y$.
  • Figure 5: Investigation of the central moment distance between latent embeddings of the training and validation sets, obtained with ERM and with EQuAD. For all three classes across the SPMotif datasets, the distances for EQuAD representations are notably smaller than those for ERM.
  • ...and 7 more figures
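Figure 5 compares training and validation embeddings with a central moment distance. One standard formulation is the central moment discrepancy (CMD) of Zellinger et al., which adds the distance between means to the distances between higher-order central moments; the sketch below implements that formulation in pure Python. Whether Figure 5 uses exactly this definition, the number of moments (`max_order=5` here), and the support width `span` are assumptions, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def _column_means(rows: Matrix) -> List[float]:
    return [sum(col) / len(rows) for col in zip(*rows)]


def _central_moments(rows: Matrix, means: List[float], order: int) -> List[float]:
    # Per-dimension k-th order central moment E[(x_d - mean_d)^k].
    return [sum((row[d] - means[d]) ** order for row in rows) / len(rows)
            for d in range(len(means))]


def _l2(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def central_moment_distance(xs: Matrix, ys: Matrix, max_order: int = 5,
                            span: float = 1.0) -> float:
    """CMD between two embedding sets (rows = samples, columns = dimensions).

    `span` is the width of the assumed bounded support [a, b]; the order-k term
    is divided by span**k, as in the CMD definition.
    """
    mx, my = _column_means(xs), _column_means(ys)
    dist = _l2(mx, my) / span
    for k in range(2, max_order + 1):
        dist += _l2(_central_moments(xs, mx, k), _central_moments(ys, my, k)) / span ** k
    return dist


train_emb = [[0.1, 0.4], [0.3, 0.2], [0.2, 0.5]]
val_emb = [[0.2, 0.3], [0.4, 0.1], [0.3, 0.6]]
print(central_moment_distance(train_emb, val_emb))
```

A smaller value means the two embedding distributions agree more closely in their first few moments, which is the sense in which Figure 5 reports EQuAD's representations as more stable than ERM's.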

Theorems & Definitions (9)

  • Theorem 4.1
  • Theorem 5.1
  • Theorem 4.1
  • proof
  • Proposition 4.2
  • proof
  • proof
  • proof
  • Definition 9.1: Two-piece graphs