Backward Oversmoothing: why is it hard to train deep Graph Neural Networks?

Nicolas Keriven

TL;DR

This paper introduces backward oversmoothing, an optimization-centered lens on the well-known oversmoothing phenomenon in deep Graph Neural Networks (GNNs). It shows that backpropagated errors are themselves smoothed, and that this backward smoothing interacts with forward smoothing to produce a nontrivial effect in the middle layers. Under mild conditions, this yields many spurious near-stationary points: once the output layer is trained, the entire network is at a global near-stationary point, often with high loss. The results are established for vanilla GNNs and shown not to hold for standard MLPs, underscoring fundamental differences between the optimization landscapes of the two architectures. While the paper does not propose remedies, it lays a theoretical foundation for understanding why deep GNNs are hard to train and motivates future optimization strategies that directly counteract these spurious flat regions.

Abstract

Oversmoothing has long been identified as a major limitation of Graph Neural Networks (GNNs): input node features are smoothed at each layer and converge to a non-informative representation, provided the weights of the GNN are sufficiently bounded. This assumption is crucial: if, on the contrary, the weights are sufficiently large, oversmoothing may not happen. In theory, a GNN could thus learn not to oversmooth. In practice, however, this does not seem to happen, which prompts us to examine oversmoothing from an optimization point of view. In this paper, we analyze backward oversmoothing, that is, the notion that the backpropagated errors used to compute gradients are also subject to oversmoothing, from output to input. With non-linear activation functions, we highlight the key role of the interaction between forward and backward smoothing. Moreover, we show that, due to backward oversmoothing, GNNs provably exhibit many spurious stationary points: as soon as the last layer is trained, the whole GNN is at a stationary point. As a result, we can exhibit regions where gradients are near zero while the loss stays high. The proof relies on the fact that, unlike forward oversmoothing, backward errors undergo linear oversmoothing even in the presence of a non-linear activation function, so that the average of the output error plays a key role. Additionally, we show that this phenomenon is specific to deep GNNs, and exhibit a counter-example for Multi-Layer Perceptrons (MLPs). This paper is a step toward a more complete understanding of the optimization landscape specific to GNNs.
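A minimal NumPy sketch of the backward pass described above, assuming a vanilla GNN layer $X^{(k)} = \mathrm{ReLU}(P X^{(k-1)} W_k)$ and a hypothetical propagation matrix $P$ (a ring graph with self-loops, symmetric and doubly stochastic). Because the ReLU masks are fixed by the forward pass, the map from the output error to the backpropagated errors is linear, even though the network itself is not. `pairwise_energy` is our own stand-in for a pairwise-difference measure, not necessarily the paper's exact definition; all function names are illustrative.

```python
import numpy as np


def ring_propagation(n: int, r: int = 2) -> np.ndarray:
    """Hypothetical P: ring graph, each node linked to its r nearest neighbours on
    each side plus a self-loop, normalized to be symmetric and doubly stochastic."""
    eye = np.eye(n)
    p = eye.copy()
    for m in range(1, r + 1):
        shift = np.roll(eye, m, axis=1)
        p += shift + shift.T
    return p / (2 * r + 1)


def pairwise_energy(x: np.ndarray) -> float:
    """||X - 1 mean(X)||_F, which equals sqrt(sum_{i,j} ||x_i - x_j||^2 / (2n))."""
    return float(np.linalg.norm(x - x.mean(axis=0, keepdims=True)))


def forward(p: np.ndarray, x0: np.ndarray, weights: list) -> list:
    """Run X^(k) = relu(P X^(k-1) W_k) and return the pre-activations Z^(k)."""
    x, zs = x0, []
    for w in weights:
        z = p @ x @ w
        zs.append(z)
        x = np.maximum(z, 0.0)
    return zs


def backward(p: np.ndarray, zs: list, weights: list, out_err: np.ndarray) -> list:
    """Backpropagate dL/dX^(K) = out_err down to dL/dX^(0).

    dL/dX^(k-1) = P^T (dL/dX^(k) * relu'(Z^(k))) W_k^T. The ReLU masks come from
    the forward pass, so for fixed weights and inputs this map is linear in out_err.
    Returns errs with errs[k] = dL/dX^(k).
    """
    errs = [out_err]
    for z, w in zip(reversed(zs), reversed(weights)):
        errs.append(p.T @ (errs[-1] * (z > 0)) @ w.T)
    return errs[::-1]


rng = np.random.default_rng(0)
n, d, depth = 12, 4, 20
p = ring_propagation(n)
x0 = rng.normal(size=(n, d))
weights = [rng.normal(size=(d, d)) * np.sqrt(2.0 / d) for _ in range(depth)]  # He-style init (assumption)
zs = forward(p, x0, weights)
out_err = rng.normal(size=(n, d))  # stand-in for dL/dX^(K), e.g. prediction minus target
errs = backward(p, zs, weights, out_err)
for k, e in enumerate(errs):
    print(f"layer {k:2d}: backward pairwise energy {pairwise_energy(e):.3e}")
```

Printing the per-layer energy of the backward errors gives a rough, toy-scale analogue of the right panel of Figure 2; the sketch makes no claim about reproducing the paper's numbers.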

Paper Structure

This paper contains 39 sections, 11 theorems, 80 equations, 2 figures.

Key Result

Theorem 1

Under Assumptions ass:P and ass:sigma, we have

$$\mathcal{E}(X^{(k)}) \leqslant (s\lambda)^{k}\, \mathcal{E}(X^{(0)}).$$

In particular, if $s\leqslant \lambda^{-\alpha}$ for $0\leqslant \alpha < 1$, then $\mathcal{E}(X^{(k)}) \leqslant \lambda^{(1-\alpha)k} \mathcal{E}(X^{(0)})$.
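A numerical sanity check of Theorem 1, written as a sketch under our own assumptions rather than the paper's exact Assumptions: $\mathcal{E}(X) = \|X - \mathbf{1}\bar{x}^\top\|_F$ (a pairwise-difference measure), $P$ symmetric and doubly stochastic with $\lambda$ its largest eigenvalue magnitude once the eigenvalue $1$ is removed, $\sigma = \mathrm{ReLU}$, and every $W_k$ rescaled to spectral norm exactly $s$. In that setting an Oono-style bound $\mathcal{E}(X^{(k)}) \leqslant (s\lambda)^k \mathcal{E}(X^{(0)})$ holds, and since $s \leqslant \lambda^{-\alpha}$ gives $s\lambda \leqslant \lambda^{1-\alpha}$, the "in particular" rate $\lambda^{(1-\alpha)k}$ follows. Note that $s = \lambda^{-\alpha} > 1$ is allowed: the weights may amplify signals, yet $\mathcal{E}$ still decays. The ring graph and helper names are hypothetical.

```python
import numpy as np


def ring_propagation(n, r=2):
    """Hypothetical P: ring graph with r neighbours per side and self-loops (symmetric, doubly stochastic)."""
    eye = np.eye(n)
    p = eye.copy()
    for m in range(1, r + 1):
        shift = np.roll(eye, m, axis=1)
        p += shift + shift.T
    return p / (2 * r + 1)


def pairwise_energy(x):
    """E(X) = ||X - 1 mean(X)||_F, our stand-in for the paper's pairwise-difference measure."""
    return float(np.linalg.norm(x - x.mean(axis=0, keepdims=True)))


def second_eigenvalue(p):
    """lambda: largest |eigenvalue| of symmetric P once the eigenvalue 1 (constant vector) is dropped."""
    eig = np.linalg.eigvalsh(p)  # ascending; eig[-1] is 1 for a connected graph
    return float(np.max(np.abs(eig[:-1])))


def check_theorem1(n=12, d=4, depth=30, alpha=0.5, seed=0):
    """Run X^(k) = relu(P X^(k-1) W_k) with ||W_k||_2 = s = lambda^(-alpha) and check both bounds."""
    rng = np.random.default_rng(seed)
    p = ring_propagation(n)
    lam = second_eigenvalue(p)
    s = lam ** (-alpha)  # s > 1: the weights may amplify signals
    x = rng.normal(size=(n, d))
    energies = [pairwise_energy(x)]
    for _ in range(depth):
        g = rng.normal(size=(d, d))
        w = s * g / np.linalg.norm(g, 2)  # rescale to spectral norm exactly s
        x = np.maximum(p @ x @ w, 0.0)
        energies.append(pairwise_energy(x))
    tol = 1e-9  # slack for floating-point rounding only
    for k, e_k in enumerate(energies):
        assert e_k <= (s * lam) ** k * energies[0] * (1 + tol) + 1e-12
        assert e_k <= lam ** ((1 - alpha) * k) * energies[0] * (1 + tol) + 1e-12
    return lam, s, energies


lam, s, energies = check_theorem1()
print(f"lambda = {lam:.4f}, s = {s:.4f}")
print(f"E(X^(0)) = {energies[0]:.3e}, E(X^(30)) = {energies[-1]:.3e}")
```

`check_theorem1` raises an `AssertionError` if either bound is violated; pushing `alpha` toward 1 slows the guaranteed decay rate $\lambda^{1-\alpha}$.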

Figures (2)

  • Figure 1: Left, center: norm of gradients with respect to epochs, for each layer, for deep GNN (left) or MLP (center) with $k=40$. The output layer is outlined. Right: loss with respect to epochs, for shallow ($k=5$) and deep ($k=40$) GNN or MLP. Node classification task on Contextual Stochastic Block Model (CSBM) data Deshpande2018.
  • Figure 2: Pairwise differences (Dirichlet energy, eq:dirichlet) for the forward signal (left) and backward signal (right) with respect to the layer index $k$, at initialization, on a node classification task on CSBM synthetic data Deshpande2018.
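To get a feel for Figure 1 on a toy problem, the sketch below computes per-layer gradient norms by hand for a deep GNN and for an MLP of the same widths (obtained by setting $P = I$). Assumptions: a regression loss $L = \|P X^{(K)} W_{\text{out}} - Y\|_F^2 / (2n)$ on random data instead of the paper's CSBM classification task, a hypothetical ring-graph $P$, and He-style Gaussian initialization; `loss_and_grads` and `gradient_norms` are illustrative names.

```python
import numpy as np


def ring_propagation(n, r=2):
    """Hypothetical P: ring graph with r neighbours per side and self-loops (symmetric, doubly stochastic)."""
    eye = np.eye(n)
    p = eye.copy()
    for m in range(1, r + 1):
        shift = np.roll(eye, m, axis=1)
        p += shift + shift.T
    return p / (2 * r + 1)


def loss_and_grads(p, x0, weights, w_out, y):
    """L = ||P X^(K) W_out - Y||_F^2 / (2n) with X^(k) = relu(P X^(k-1) W_k).

    Returns (loss, [dL/dW_1, ..., dL/dW_K, dL/dW_out]). Passing p = np.eye(n)
    gives an MLP with the same widths.
    """
    n = x0.shape[0]
    xs, zs = [x0], []
    for w in weights:
        zs.append(p @ xs[-1] @ w)
        xs.append(np.maximum(zs[-1], 0.0))
    h = p @ xs[-1]
    resid = h @ w_out - y
    loss = float(np.sum(resid ** 2)) / (2 * n)
    err = resid / n  # dL/d(prediction)
    grads = [h.T @ err]  # dL/dW_out
    back = p.T @ err @ w_out.T  # dL/dX^(K)
    for k in range(len(weights) - 1, -1, -1):
        dz = back * (zs[k] > 0)  # dL/dZ^(k+1), with Z^(k+1) = P X^(k) W_{k+1}
        grads.append((p @ xs[k]).T @ dz)  # dL/dW_{k+1}
        back = p.T @ dz @ weights[k].T  # dL/dX^(k)
    return loss, grads[::-1]


def gradient_norms(depth, use_graph, n=12, d=8, seed=0):
    """Per-layer gradient norms at initialization (last entry: output layer)."""
    rng = np.random.default_rng(seed)
    p = ring_propagation(n) if use_graph else np.eye(n)
    x0 = rng.normal(size=(n, d))
    y = rng.normal(size=(n, 1))
    weights = [rng.normal(size=(d, d)) * np.sqrt(2.0 / d) for _ in range(depth)]
    w_out = rng.normal(size=(d, 1)) / np.sqrt(d)
    _, grads = loss_and_grads(p, x0, weights, w_out, y)
    return [float(np.linalg.norm(g)) for g in grads]


for name, use_graph in [("GNN", True), ("MLP", False)]:
    for depth in (5, 40):
        norms = gradient_norms(depth, use_graph)
        print(f"{name} k={depth}: first layer {norms[0]:.2e}, output layer {norms[-1]:.2e}")
```

Only gradients at initialization are computed here; tracking them over epochs, as Figure 1 does, would require wrapping `loss_and_grads` in a training loop.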

Theorems & Definitions (25)

  • Example 1: Regression
  • Example 2: Classification
  • Theorem 1: Forward oversmoothing Oono2020
  • Theorem 2
  • Corollary 1
  • Proposition 1
  • Definition 1: $\delta$-stationary point
  • Theorem 3: A stationary output layer is a global stationary point
  • Corollary 2: Spurious stationary points
  • Proposition 2: MLPs are not GNNs
  • ...and 15 more
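Definition 1 and Theorem 3 can be probed with a sketch under the following assumptions: a parameter point is read as $\delta$-stationary when every parameter block has gradient norm at most $\delta$ (our reading, not necessarily the paper's exact definition), and "training the output layer" means solving the least-squares problem for $W_{\text{out}}$ in closed form with the hidden weights frozen, so the output-layer gradient vanishes up to floating-point error. The script then reports the hidden-layer gradients and the remaining loss; Theorem 3 suggests that in a sufficiently deep GNN these gradients should also be small, but the sketch only prints them rather than asserting it. The ring graph, regression loss, He-style init, and the names `fit_output_layer` and `is_delta_stationary` are hypothetical.

```python
import numpy as np


def ring_propagation(n, r=2):
    """Hypothetical P: ring graph with r neighbours per side and self-loops (symmetric, doubly stochastic)."""
    eye = np.eye(n)
    p = eye.copy()
    for m in range(1, r + 1):
        shift = np.roll(eye, m, axis=1)
        p += shift + shift.T
    return p / (2 * r + 1)


def loss_and_grads(p, x0, weights, w_out, y):
    """L = ||P X^(K) W_out - Y||_F^2 / (2n); returns (loss, [dL/dW_1, ..., dL/dW_K, dL/dW_out])."""
    n = x0.shape[0]
    xs, zs = [x0], []
    for w in weights:
        zs.append(p @ xs[-1] @ w)
        xs.append(np.maximum(zs[-1], 0.0))
    h = p @ xs[-1]
    resid = h @ w_out - y
    err = resid / n  # dL/d(prediction)
    grads = [h.T @ err]  # dL/dW_out
    back = p.T @ err @ w_out.T  # dL/dX^(K)
    for k in range(len(weights) - 1, -1, -1):
        dz = back * (zs[k] > 0)  # dL/dZ^(k+1)
        grads.append((p @ xs[k]).T @ dz)  # dL/dW_{k+1}
        back = p.T @ dz @ weights[k].T  # dL/dX^(k)
    return float(np.sum(resid ** 2)) / (2 * n), grads[::-1]


def is_delta_stationary(grads, delta):
    """Hypothetical reading of a delta-stationary point: every gradient block has norm <= delta."""
    return all(float(np.linalg.norm(g)) <= delta for g in grads)


def fit_output_layer(p, x0, weights, y):
    """'Train' only W_out: least squares on H = P X^(K) with the hidden weights frozen."""
    x = x0
    for w in weights:
        x = np.maximum(p @ x @ w, 0.0)
    w_out, *_ = np.linalg.lstsq(p @ x, y, rcond=None)
    return w_out


rng = np.random.default_rng(0)
n, d, depth = 12, 8, 40
p = ring_propagation(n)
x0 = rng.normal(size=(n, d))
y = rng.normal(size=(n, 1))
weights = [rng.normal(size=(d, d)) * np.sqrt(2.0 / d) for _ in range(depth)]
w_out = fit_output_layer(p, x0, weights, y)
loss, grads = loss_and_grads(p, x0, weights, w_out, y)
norms = [float(np.linalg.norm(g)) for g in grads]
print(f"loss after fitting W_out: {loss:.3f} (loss of predicting 0: {np.sum(y ** 2) / (2 * n):.3f})")
print(f"output-layer gradient norm: {norms[-1]:.2e}")
print(f"largest hidden-layer gradient norm: {max(norms[:-1]):.2e}")
print("1e-3-stationary:", is_delta_stationary(grads, 1e-3))
```

Comparing the printed loss with the loss of the zero predictor indicates whether the near-stationary point reached by fitting only $W_{\text{out}}$ still has high loss on this toy instance.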