A Quasi-Wasserstein Loss for Learning Graph Neural Networks

Minjie Cheng, Hongteng Xu

TL;DR

This study designs a "Quasi-Wasserstein" distance between the observed multi-dimensional node labels and their estimations, defined by optimizing the label transport on graph edges, which leads to new learning and prediction paradigms for GNNs.

Abstract

When learning graph neural networks (GNNs) for node-level prediction tasks, most existing loss functions are applied to each node independently, even though node embeddings and their labels are non-i.i.d. because of the graph structure. To eliminate this inconsistency, we propose a novel Quasi-Wasserstein (QW) loss based on optimal transport defined on graphs, leading to new learning and prediction paradigms for GNNs. In particular, we design a "Quasi-Wasserstein" distance between the observed multi-dimensional node labels and their estimations by optimizing the label transport defined on graph edges. The estimations are parameterized by a GNN, in which the optimal label transport can optionally determine the graph edge weights. By reformulating the strict label-transport constraint as a Bregman divergence-based regularizer, we obtain the proposed QW loss, together with two efficient solvers that learn the GNN jointly with the optimal label transport. When predicting node labels, our model combines the output of the GNN with the residual component provided by the optimal label transport, leading to a new transductive prediction paradigm. Experiments show that the proposed QW loss applies to various GNNs and helps improve their performance in node-level classification and regression tasks. The code for this work is available at https://github.com/SDS-Lab/QW_Loss.
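To make the abstract's ingredients concrete, here is a minimal sketch under stated assumptions. It is not the authors' implementation in the QW_Loss repository, and it is not either of the paper's two solvers. It builds a signed node-edge incidence matrix S and holds the GNN estimates `Y_hat` fixed, whereas the paper learns the GNN jointly. It replaces the Bregman divergence with the squared Euclidean distance, which is one member of the Bregman family. It then minimizes the edge-weighted L1 transport cost plus a λ-weighted fit term on labeled nodes with proximal gradient descent (ISTA). Multi-dimensional labels are handled as matrix columns. The helper names `incidence_matrix`, `relaxed_qw`, and `objective`, the toy graph, and the values of λ and the iteration count are all hypothetical. The last line forms the transductive prediction "GNN output plus transport residual".

```python
import numpy as np


def incidence_matrix(num_nodes, edges):
    """Signed incidence matrix S of shape (|V|, |E|) (sign convention assumed).

    Edge e = (i, j) gets S[i, e] = +1 and S[j, e] = -1, so a positive flow
    F[e] moves label mass from node j to node i, and (S @ F)[v] is the net
    mass delivered to node v.
    """
    S = np.zeros((num_nodes, len(edges)))
    for e, (i, j) in enumerate(edges):
        S[i, e] = 1.0
        S[j, e] = -1.0
    return S


def relaxed_qw(Y, Y_hat, S, w, observed, lam=10.0, num_iters=500):
    """ISTA sketch of a relaxed QW objective (hypothetical, Y_hat held fixed):

        min_F  sum_e w_e * ||F[e]||_1
               + (lam / 2) * sum_{v observed} ||Y_hat[v] + (S @ F)[v] - Y[v]||^2
    """
    mask = observed.astype(float)[:, None]           # 1 for labeled nodes, 0 otherwise
    F = np.zeros((S.shape[1], Y.shape[1]))           # edge flows, one column per label dim
    step = 1.0 / (lam * np.linalg.norm(S, 2) ** 2)   # 1 / Lipschitz bound of the smooth term
    for _ in range(num_iters):
        residual = mask * (Y_hat + S @ F - Y)        # mismatch on labeled nodes only
        Z = F - step * lam * (S.T @ residual)        # gradient step on the quadratic term
        thresh = step * w[:, None]                   # per-edge L1 threshold
        F = np.sign(Z) * np.maximum(np.abs(Z) - thresh, 0.0)  # soft-thresholding (prox of L1)
    return F


def objective(F, Y, Y_hat, S, w, observed, lam):
    """Value of the relaxed objective minimized by relaxed_qw."""
    mask = observed.astype(float)[:, None]
    transport_cost = np.sum(w[:, None] * np.abs(F))
    fit = 0.5 * lam * np.sum((mask * (Y_hat + S @ F - Y)) ** 2)
    return transport_cost + fit


# Toy example: a 4-node path graph with 2-dimensional labels (values made up).
edges = [(0, 1), (1, 2), (2, 3)]
S = incidence_matrix(4, edges)
w = np.ones(len(edges))
Y = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
Y_hat = np.full((4, 2), 0.5)                     # stand-in for GNN outputs
observed = np.array([True, True, False, True])   # node 2 is unlabeled

F = relaxed_qw(Y, Y_hat, S, w, observed)
prediction = Y_hat + S @ F                       # GNN output plus transport residual
```

Every column of S sums to zero, so the residual `S @ F` only moves label mass along edges and never creates it. ISTA with step 1/L does not increase the objective, so the fit on labeled nodes is no worse than that of `Y_hat` alone. The L1 cost keeps the flows sparse.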

Paper Structure (41 sections, 3 theorems, 21 equations, 5 figures, 9 tables, 2 algorithms)

Key Result

Theorem 1

Given an undirected graph $G(\mathcal{V},\mathcal{E})$ with edge weights $\bm{w}\in[0,\infty)^{|\mathcal{E}|}$ and a matrix $\bm{S}_{\mathcal{V}}\in\{0,\pm 1\}^{|\mathcal{V}|\times|\mathcal{E}|}$ defined in (eq:topo), the $W_1$ in (eq:got_directed) is a metric on $\mathrm{Range}(\bm{S}_{\mathcal{V}})$, and th…
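As a small worked instance of Theorem 1, assume that (eq:got_directed) has the common min-cost-flow form $W_1(\bm{\mu},\bm{\nu})=\min_{\bm{f}}\sum_{e} w_e|f_e|$ subject to $\bm{S}_{\mathcal{V}}\bm{f}=\bm{\mu}-\bm{\nu}$. This is an assumption, since the equation itself is not reproduced here. On a path graph the feasible flow is unique and equals the cumulative sum of $\bm{\mu}-\bm{\nu}$. A feasible flow exists only when the difference sums to zero, which for a connected graph is exactly the condition $\bm{\mu}-\bm{\nu}\in\mathrm{Range}(\bm{S}_{\mathcal{V}})$. The hypothetical helpers `path_incidence` and `path_w1` compute this closed form on toy distributions.

```python
import numpy as np


def path_incidence(n):
    """Signed incidence matrix of the path 0-1-...-(n-1); edge k joins nodes k and k+1."""
    S = np.zeros((n, n - 1))
    for k in range(n - 1):
        S[k, k] = 1.0
        S[k + 1, k] = -1.0
    return S


def path_w1(mu, nu, w):
    """Closed-form W1 on a path graph (assumes the min-cost-flow form of W1).

    Returns the cost and the unique flow f with path_incidence(n) @ f == mu - nu.
    """
    d = np.asarray(mu, dtype=float) - np.asarray(nu, dtype=float)
    if not np.isclose(d.sum(), 0.0):
        raise ValueError("mu - nu must sum to zero to lie in Range(S)")
    f = np.cumsum(d)[:-1]  # flow on edge k = net surplus of nodes 0..k
    return float(np.sum(np.asarray(w, dtype=float) * np.abs(f))), f


S = path_incidence(4)
w = np.ones(3)
mu = np.array([1.0, 0.0, 0.0, 0.0])
nu = np.array([0.0, 0.0, 0.0, 1.0])
xi = np.array([0.0, 1.0, 0.0, 0.0])

cost, f = path_w1(mu, nu, w)
print(cost)  # → 3.0 (one unit of mass crosses all three edges)
```

On general graphs the flow is no longer unique, and $W_1$ needs an LP or a solver such as those in the paper. The path case still shows why the metric is defined only on $\mathrm{Range}(\bm{S}_{\mathcal{V}})$: outside it, no flow satisfies the constraint.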

Figures (5)

  • Figure 1: The scheme of our QW loss and the corresponding learning paradigm. Given a graph whose node features are denoted by blue circles and whose partially observed node labels are denoted by blue stems, a GNN embeds the nodes and estimates their labels (orange stems). By minimizing the QW loss, we obtain the optimal label transport (dotted red arrows on edges) between the real and estimated node labels. Optionally, the optimal label transport determines the weights of graph edges through an edge weight predictor. The final predictions combine the optimal label transport with the estimated node labels.
  • Figure 2: The runtime of different learning methods on the Photo graph.
  • Figure 3: The histogram of $\bm{F}$'s values for different GNNs.
  • Figure 4: The learning methods' performance given different numbers of labeled nodes.
  • Figure 5: The impact of $\lambda$ on the learning results.

Theorems & Definitions (6)

  • Theorem 1: Metric Property
  • Theorem 2: Monotonicity
  • Theorem 3
  • Proof
  • Proof
  • Proof