Residual Connections and Normalization Can Provably Prevent Oversmoothing in GNNs

Michael Scholkemper, Xinyi Wu, Ali Jadbabaie, Michael T. Schaub

TL;DR

The paper analyzes how residual connections and normalization layers influence oversmoothing in GNNs, using a linearized model to obtain precise convergence characterizations. It shows that initial residual connections preserve information in the Krylov subspace generated by the message-passing operator, preventing complete rank collapse, while batch normalization directs representations toward the top-k eigenspace of the centered operator, under broad conditions. A critical finding is that the centering step in normalization can distort the graph signal and structural eigenvectors, motivating GraphNormv2, which learns a centering projection to avoid undesired distortions. Empirical results across deep GCNs and various backbones corroborate the theory, showing improved maintenance of expressive subspaces and enhanced performance on several graph learning benchmarks.
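The contrast between plain message passing and an initial residual connection can be illustrated numerically. The following is a minimal sketch, not the paper's experimental setup: it assumes a linear model without weight matrices, a fixed scalar residual strength `alpha`, a toy 6-node cycle graph, and a GCN-style operator with self-loops. The helper names `normalized_adjacency` and `propagate` are hypothetical.

```python
import numpy as np


def normalized_adjacency(adj):
    """Return D^{-1/2} (A + I) D^{-1/2}; adding self-loops is an assumption here."""
    a_hat = adj + np.eye(adj.shape[0])
    deg = a_hat.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def propagate(op, x0, steps, alpha=0.0):
    """Iterate X <- (1 - alpha) * op @ X + alpha * X0.

    alpha = 0 is plain message passing; alpha > 0 re-injects the
    initial features at every layer (initial residual connection).
    """
    x = x0.copy()
    for _ in range(steps):
        x = (1.0 - alpha) * (op @ x) + alpha * x0
    return x


# Toy graph (assumption): undirected cycle on 6 nodes.
n = 6
adj = np.zeros((n, n))
for i in range(n):
    adj[i, (i + 1) % n] = 1.0
    adj[(i + 1) % n, i] = 1.0

op = normalized_adjacency(adj)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((n, 3))  # 3 random feature columns

x_vanilla = propagate(op, x0, steps=100)
x_residual = propagate(op, x0, steps=100, alpha=0.1)

# Numerical rank with an explicit absolute tolerance.
rank_vanilla = np.linalg.matrix_rank(x_vanilla, tol=1e-8)
rank_residual = np.linalg.matrix_rank(x_residual, tol=1e-8)
print("vanilla rank:", rank_vanilla, "| residual rank:", rank_residual)
```

On this toy graph the vanilla iterate should be numerically rank one, since every column aligns with the dominant eigenvector of the operator. The residual iterate is a polynomial in the operator applied to `x0`, so it stays in the Krylov subspace generated from the initial features and should keep their rank.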

Abstract

Residual connections and normalization layers have become standard design choices for graph neural networks (GNNs), and were proposed as solutions to mitigate the oversmoothing problem in GNNs. However, how exactly these methods help alleviate the oversmoothing problem from a theoretical perspective is not well understood. In this work, we provide a formal and precise characterization of (linearized) GNNs with residual connections and normalization layers. We establish that (a) for residual connections, the incorporation of the initial features at each layer can prevent the signal from becoming too smooth, and determines the subspace of possible node representations; (b) batch normalization prevents a complete collapse of the output embedding space to a one-dimensional subspace through the individual rescaling of each column of the feature matrix. This results in the convergence of node representations to the top-$k$ eigenspace of the message-passing operator; (c) moreover, we show that the centering step of a normalization layer -- which can be understood as a projection -- alters the graph signal in message-passing in such a way that relevant information can become harder to extract. We therefore introduce a novel, principled normalization layer called GraphNormv2 in which the centering step is learned such that it does not distort the original graph signal in an undesirable way. Experimental results confirm the effectiveness of our method.
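The remark that centering can be understood as a projection can be checked directly. The sketch below assumes the standard column-wise mean subtraction used by batch-style normalization and writes it as multiplication by the matrix $I - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$; the variable names are illustrative, and nothing here reproduces GraphNormv2 itself.

```python
import numpy as np

n, k = 5, 3
rng = np.random.default_rng(0)
x = rng.standard_normal((n, k))  # n nodes, k feature columns

ones = np.ones((n, 1))
proj = np.eye(n) - ones @ ones.T / n  # centering matrix I - (1/n) 1 1^T

# Subtracting each column's mean equals left-multiplying by proj.
centered_by_mean = x - x.mean(axis=0, keepdims=True)
centered_by_proj = proj @ x

is_same = np.allclose(centered_by_mean, centered_by_proj)
is_idempotent = np.allclose(proj @ proj, proj)    # P^2 = P
is_symmetric = np.allclose(proj, proj.T)          # P = P^T
kills_constants = np.allclose(proj @ ones, 0.0)   # P 1 = 0
print(is_same, is_idempotent, is_symmetric, kills_constants)
```

Because the matrix is symmetric and idempotent, it is an orthogonal projection onto the complement of the all-ones vector. Any component of the graph signal along that direction is removed regardless of the graph, which is one way to see how a fixed centering step can interact with the structural eigenvectors.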

Paper Structure

This paper contains 45 sections, 20 theorems, 92 equations, 2 figures, 4 tables.

Key Result

Proposition 4.0

Let $v \in \mathbb{R}^n$ s.t. $\left\| v \right\|_2 = 1$. If $\mu_v(X^{(0)}) > 0$, then w.h.p. $\exists c > 0$ s.t. $\mu_v(X^{(t)}) \geq c$.

Figures (2)

  • Figure 1: Long-term behavior of GCN. Mean progression (over 10 independent trials) of $\mu_v(X^{(t)})$ and $\operatorname{Rank}(X^{(t)})$ over $256$ iterations of message-passing in both linear and non-linear GCN, where $v$ is the dominant eigenvector of $D^{-1/2}A_{\mathrm{adj}}D^{-1/2}$. In the linear case, $\mu_v(X^{(t)})$ remains constant for all methods except vanilla GCN, indicating that complete collapse to the dominant eigenspace does not happen. However, PairNorm does collapse in terms of rank, while the other methods maintain a rank greater than $2$. All of these phenomena are explained by our theory. In the non-linear case, the models behave similarly. Notably, centering seems to prevent rank collapse in the non-linear case, as PairNorm no longer collapses in rank.
  • Figure 2: Long-term behavior of GCN performance. Classification accuracy and standard deviation of GCN models of varying depth. The x-axis shows the depth of the GCN, while the y-axis shows classification accuracy with standard deviations.

Theorems & Definitions (35)

  • Proposition 4.0
  • Proposition 4.0
  • Proposition 4.0
  • Remark 4.1
  • Remark 4.2
  • Proposition 4.2
  • Remark 4.3
  • Proposition 4.3
  • Proposition 4.3
  • Remark 4.4
  • ...and 25 more