
Learning Divergence Fields for Shift-Robust Graph Representations

Qitian Wu, Fan Nie, Chenxiao Yang, Junchi Yan

TL;DR

Problem: robust generalization under distribution shifts for interdependent data on graphs/manifolds. Approach: a geometric diffusion model with learnable divergence fields. At each step, $\mathbf d^{(l)}$ is sampled from $p(\mathbf d^{(l)}|\mathbf z^{(l)})$ and each node state is updated via $\mathbf z_u^{(l+1)} = \mathbf z_u^{(l)} + \alpha \sum_{v \in \mathcal N(u)} d_{uv}^{(l)} (\mathbf z_v^{(l)} - \mathbf z_u^{(l)})$. This is combined with a causal regularization that optimizes a variational lower bound on the interventional likelihood $\log p_\theta(\mathbf y|do(\mathbf x), \mathcal G)$. Contributions: (i) diffusion on graphs with stochastic divergences; (ii) a step-wise re-weighting regularization that uses a prior $p_0(\mathbf d^{(l)})$ to approximate $p_\theta(\mathbf y|do(\mathbf x), \mathcal G)$; (iii) three practical backbones, Glind-GCN, Glind-GAT, and Glind-Trans, together with a data-driven prior defined as a mixture of pseudo posteriors, $p_0(\mathbf d^{(l)}) = \frac{1}{T} \sum_{t=1}^T q(\mathbf d^{(l)}|\mathbf z^{(l)}=\tilde{\mathbf z}^{(l)}_t)$; (iv) extensive experiments showing improved out-of-distribution generalization on datasets with both observed and latent geometries. Significance: enables shift-robust graph representations applicable to diverse domains.
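The update rule above can be illustrated with a small, dependency-free sketch. It is not the authors' implementation. The two candidate fields (`uniform_field`, GCN-like, and `attention_field`, attention-like) and the per-node choice probabilities `probs` are hypothetical stand-ins for the learned distribution $p(\mathbf d^{(l)}|\mathbf z^{(l)})$. All function names are illustrative.

```python
import math
import random

def uniform_field(z, neighbors):
    """Hypothetical GCN-like candidate: weight 1/|N(u)| on every neighbour v of u."""
    return {u: {v: 1.0 / len(nbrs) for v in nbrs} for u, nbrs in enumerate(neighbors)}

def attention_field(z, neighbors):
    """Hypothetical attention-like candidate: softmax over N(u) of <z_u, z_v>."""
    field = {}
    for u, nbrs in enumerate(neighbors):
        if not nbrs:
            field[u] = {}
            continue
        scores = [sum(a * b for a, b in zip(z[u], z[v])) for v in nbrs]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        field[u] = {v: e / total for v, e in zip(nbrs, exps)}
    return field

def sample_divergence(z, neighbors, fields, probs, rng):
    """Assumed stand-in for d^(l) ~ p(d^(l) | z^(l)): node u picks candidate
    field k with probability probs[u][k] and uses that field's weights d_uv."""
    candidates = [f(z, neighbors) for f in fields]
    d = {}
    for u in range(len(z)):
        k = rng.choices(range(len(fields)), weights=probs[u])[0]
        d[u] = candidates[k][u]
    return d

def diffusion_step(z, neighbors, d, alpha):
    """z_u <- z_u + alpha * sum_{v in N(u)} d_uv * (z_v - z_u), for every node u."""
    new_z = []
    for u, z_u in enumerate(z):
        update = [0.0] * len(z_u)
        for v in neighbors[u]:
            for j in range(len(z_u)):
                update[j] += d[u][v] * (z[v][j] - z_u[j])
        new_z.append([z_u[j] + alpha * update[j] for j in range(len(z_u))])
    return new_z

def diffuse(x, neighbors, fields, probs, alpha, num_layers, seed=0):
    """Trajectory z^(0) = x -> z^(1) -> ... -> z^(L), resampling d at each step."""
    rng = random.Random(seed)
    z = x
    for _ in range(num_layers):
        d = sample_divergence(z, neighbors, fields, probs, rng)
        z = diffusion_step(z, neighbors, d, alpha)
    return z

# Usage on a path graph 0 - 1 - 2 with 1-dimensional node features.
neighbors = [[1], [0, 2], [1]]
x = [[0.0], [1.0], [2.0]]
probs = [[0.5, 0.5]] * 3  # fixed probabilities standing in for a learned q(d | z)
z_L = diffuse(x, neighbors, [uniform_field, attention_field], probs, alpha=0.5, num_layers=3)
```

When each row of $d^{(l)}$ sums to one, as both candidate fields here do, the update equals $(1-\alpha)\mathbf z_u^{(l)} + \alpha \sum_v d_{uv}^{(l)} \mathbf z_v^{(l)}$. For $0 < \alpha \le 1$ this is a convex combination, which keeps the states bounded across steps.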

Abstract

Real-world data generation often involves geometries (e.g., graphs) that induce instance-level interdependence. This makes generalization harder for learning models, because intricate interdependence patterns affect the data-generating distributions and can change from training to testing. In this work, we propose a geometric diffusion model with learnable divergence fields for this challenging generalization problem with interdependent data. We generalize the diffusion equation with stochastic diffusivity at each time step, aiming to capture the multi-faceted information flows among interdependent data. Furthermore, we derive a new learning objective through causal inference, which guides the model to learn generalizable patterns of interdependence that are insensitive to domain changes. For practical implementation, we introduce three model instantiations that can be viewed as generalized versions of GCN, GAT, and Transformers, respectively, and that offer stronger robustness against distribution shifts. We demonstrate their promising efficacy for out-of-distribution generalization on diverse real-world datasets.
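The following standard derivation links the "diffusion equation with stochastic diffusivity" to the discrete update rule stated in the TL;DR. It assumes the usual graph operators, a gradient $(\nabla \mathbf z)_{uv} = \mathbf z_v - \mathbf z_u$ and a divergence $(\operatorname{div} \mathbf F)_u = \sum_{v \in \mathcal N(u)} \mathbf F_{uv}$, together with an explicit Euler discretization of step size $\alpha$. It is a sketch under these assumptions, not a reproduction of the paper's own derivation.

```latex
\frac{\partial \mathbf z_u(t)}{\partial t}
  = \operatorname{div}\!\big(\mathbf d(t) \odot \nabla \mathbf z(t)\big)_u
  = \sum_{v \in \mathcal N(u)} d_{uv}(t)\,\big(\mathbf z_v(t) - \mathbf z_u(t)\big)
\\[4pt]
\text{Euler step } t = l\alpha \;\Rightarrow\;
\mathbf z_u^{(l+1)} = \mathbf z_u^{(l)} + \alpha \sum_{v \in \mathcal N(u)} d_{uv}^{(l)} \big(\mathbf z_v^{(l)} - \mathbf z_u^{(l)}\big),
\qquad \mathbf d^{(l)} \sim p(\mathbf d^{(l)} \mid \mathbf z^{(l)})
```

The diffusivity becomes stochastic because a fresh $\mathbf d^{(l)}$ is drawn at every step instead of being fixed across the whole trajectory.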

Paper Structure (19 sections, 1 theorem, 30 equations, 5 figures, 4 tables)

Key Result

Theorem 3.1

For any given diffusion model $p_\theta(\mathbf z^{(l+1)}|\mathbf z^{(l)}, \mathbf d^{(l)}, \mathcal{G})$, the deconfounded learning objective admits the lower bound $\log p_\theta(\mathbf y|do(\mathbf x), \mathcal{G}) \geq \cdots$, where $p_0(\mathbf d^{(l)})$ is a model-free prior distribution. In particular, equality holds in Eqn. eqn-obj-reweight iff $q_\phi(\mathbf d^{(l)} | \mathbf z^{(l)})$ …
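The theorem relies on a prior $p_0(\mathbf d^{(l)})$, and the TL;DR describes a data-driven choice: $p_0(\mathbf d^{(l)}) = \frac{1}{T}\sum_{t=1}^T q(\mathbf d^{(l)}|\mathbf z^{(l)}=\tilde{\mathbf z}^{(l)}_t)$. The sketch below computes such a mixture. It assumes, hypothetically, that $q$ is a categorical distribution over $K$ candidate diffusivities whose logits are a linear map (`weights`) of a single state vector. It also assumes the pseudo states $\tilde{\mathbf z}_t$ are given. The names `q_posterior` and `mixture_prior` are illustrative and do not come from the paper.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def q_posterior(z, weights):
    """Hypothetical q(d | z): categorical over K candidate diffusivities,
    with logits given by a linear map (one row of `weights` per candidate)."""
    logits = [sum(w_i * z_i for w_i, z_i in zip(row, z)) for row in weights]
    return softmax(logits)

def mixture_prior(pseudo_states, weights):
    """p_0(d) = (1/T) * sum_t q(d | z = z_tilde_t), averaged over T pseudo states."""
    T = len(pseudo_states)
    K = len(weights)
    p0 = [0.0] * K
    for z_tilde in pseudo_states:
        q = q_posterior(z_tilde, weights)
        for k in range(K):
            p0[k] += q[k] / T
    return p0

# Usage: T = 3 pseudo states in 2 dimensions, K = 2 candidate diffusivities.
pseudo_states = [[0.3, -1.2], [2.0, 0.5], [-0.7, 0.1]]
weights = [[1.0, 2.0], [0.5, -1.0]]
p0 = mixture_prior(pseudo_states, weights)
```

Because $p_0$ averages valid distributions, it is itself a valid distribution. It depends only on the pseudo states and not on the observed $\mathbf z^{(l)}$, which is what allows it to act as a reference distribution for the diffusivity.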

Figures (5)

  • Figure 1: The challenge of generalization with interdependent data involves distribution shifts regarding the underlying manifolds that define the proximity among data samples.
  • Figure 2: (a)-(c) Dependence among random variables of interest (here we omit $\mathcal{G}$ for brevity, since everything can be treated as conditioned on $\mathcal{G}$). (a) Causal dependence for data generation, where the diffusivity $\mathbf d$ is the common cause of $\mathbf x$ and $\mathbf y$. (b) Deconfounded learning $p_\theta(\mathbf y|do(\mathbf x), \mathcal{G})$, which aims to cut off the dependence path from $\mathbf d$ to $\mathbf x$ in order to learn a causal (a.k.a. stable) relation between $\mathbf x$ and $\mathbf y$ for generalization. (c) Our diffusion model, whose feed-forward dynamics $\mathbf x_u = \mathbf z_u^{(0)} \rightarrow \mathbf z_u^{(1)} \rightarrow \cdots \rightarrow \mathbf z_u^{(L)} = \hat{\mathbf y}_u$ are given by a predictive distribution $p_\theta(\mathbf z^{(l+1)} | \mathbf z^{(l)}, \mathbf d^{(l)}, \mathcal{G})$ and a variational distribution $q_\phi(\mathbf d^{(l)}| \mathbf z^{(l)})$. (d) The geometric diffusion model, optimized by a new learning objective (a supervised term plus a regularization term) that achieves the goal of deconfounded learning $p_\theta(\mathbf y|do(\mathbf x), \mathcal{G})$.
  • Figure 3: (a) Ablation studies for Glind on STL. (b) Performance of Glind on Arxiv with different values of $K$.
  • Figure 4: t-SNE visualization of input features for training instances (blue) and testing instances (red) on Arxiv.
  • Figure 5: (a) Ablation studies for Glind on CIFAR. (b) Performance of Glind on three testing sets of Twitch with different numbers $K$ of diffusivities.
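Figure 2(a)-(b) describes the diffusivity $\mathbf d$ as a common cause of $\mathbf x$ and $\mathbf y$, and deconfounded learning as cutting the path $\mathbf d \rightarrow \mathbf x$. In the discrete case this corresponds to the standard backdoor adjustment $p(\mathbf y|do(\mathbf x)) = \sum_{\mathbf d} p(\mathbf y|\mathbf x, \mathbf d)\,p(\mathbf d)$, which replaces $p(\mathbf d|\mathbf x)$ with $p(\mathbf d)$. The TL;DR describes the re-weighting with $p_0$ as an approximation to this quantity. The toy sketch below uses made-up binary probabilities, not data from the paper, to show how far the observational and interventional quantities can diverge.

```python
P_D = {0: 0.5, 1: 0.5}           # hypothetical p(d)
P_X1_GIVEN_D = {0: 0.1, 1: 0.9}  # hypothetical p(x=1 | d): d strongly drives x

def p_y1_given_x_d(x, d):
    """Hypothetical p(y=1 | x, d): y depends mostly on the confounder d, weakly on x."""
    return 0.2 + 0.1 * x + 0.6 * d

def p_x_given_d(x, d):
    return P_X1_GIVEN_D[d] if x == 1 else 1.0 - P_X1_GIVEN_D[d]

def observational(x):
    """p(y=1 | x) = sum_d p(y=1 | x, d) p(d | x), with p(d | x) from Bayes' rule."""
    joint = {d: p_x_given_d(x, d) * P_D[d] for d in P_D}
    p_x = sum(joint.values())
    return sum(p_y1_given_x_d(x, d) * joint[d] / p_x for d in P_D)

def interventional(x):
    """Backdoor adjustment: p(y=1 | do(x)) = sum_d p(y=1 | x, d) p(d)."""
    return sum(p_y1_given_x_d(x, d) * P_D[d] for d in P_D)

for x in (0, 1):
    print(x, observational(x), interventional(x))
```

With these numbers, $p(y{=}1|x)$ moves from about 0.26 to 0.84 as $x$ goes from 0 to 1, mostly because $x$ carries information about $d$. The interventional $p(y{=}1|do(x))$ moves only from 0.5 to 0.6, which is the direct effect encoded in `p_y1_given_x_d`. A model fitted to the observational relation would therefore rely on a pattern that breaks once $p(\mathbf d)$ shifts between training and testing.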

Theorems & Definitions (1)

  • Theorem 3.1