GLL: A Differentiable Graph Learning Layer for Neural Networks
Jason Brown, Bohan Chen, Harris Hardiman-Mostow, Jeff Calder, Andrea L. Bertozzi
TL;DR
This work introduces the Graph Learning Layer (GLL), a differentiable layer that integrates graph-based label propagation directly into neural network training. By deriving exact adjoint-based backpropagation through graph Laplace equations and similarity-graph construction, GLL jointly learns feature representations while performing graph-informed classification, replacing the traditional MLP head and softmax. The approach yields smoother embeddings, improved generalization, and notably stronger adversarial robustness across datasets and architectures, particularly at low label rates. Extensive experiments, including large-scale experiments on CIFAR-10 and EMNIST and ablations on toy and over-parameterized models, demonstrate the practical viability and benefits of end-to-end differentiable graph learning in deep networks.
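To make "graph-based label propagation" concrete, here is a minimal NumPy sketch of Laplace learning on a dense Gaussian similarity graph. It solves L_UU U_U = W_UL Y_L for the unlabeled nodes, where L = D - W is the graph Laplacian and Y_L holds one-hot labels. The Gaussian kernel, the bandwidth `sigma`, and the helper names `gaussian_weights` and `laplace_learning` are illustrative assumptions and do not reproduce the authors' implementation.

```python
import numpy as np

def gaussian_weights(X, sigma=1.0):
    """Dense similarity graph: W_ij = exp(-||x_i - x_j||^2 / sigma^2), with W_ii = 0."""
    sq_dists = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    W = np.exp(-sq_dists / sigma**2)
    np.fill_diagonal(W, 0.0)
    return W

def laplace_learning(W, labeled_idx, labels, num_classes):
    """Harmonic extension of one-hot labels: solve L_UU U_U = W_UL Y_L."""
    n = W.shape[0]
    unlabeled_idx = np.setdiff1d(np.arange(n), labeled_idx)
    Y_L = np.eye(num_classes)[labels]           # one-hot labels, shape (|L|, k)
    lap = np.diag(W.sum(axis=1)) - W            # graph Laplacian L = D - W
    L_UU = lap[np.ix_(unlabeled_idx, unlabeled_idx)]
    W_UL = W[np.ix_(unlabeled_idx, labeled_idx)]
    U_U = np.linalg.solve(L_UU, W_UL @ Y_L)     # class scores for unlabeled nodes
    return unlabeled_idx, U_U

# Toy example: two well-separated 1D clusters, one labeled point in each.
X = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
W = gaussian_weights(X, sigma=1.0)
unlabeled_idx, U_U = laplace_learning(W, np.array([0, 3]), np.array([0, 1]), num_classes=2)
print(unlabeled_idx, U_U.argmax(axis=1))  # → [1 2 4 5] [0 0 1 1]
```

Because constant functions are harmonic, each row of `U_U` sums to 1. In this sketch the prediction is an argmax over those scores, which plays the role that a projection head plus softmax plays in a standard classifier.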
Abstract
Standard deep learning architectures used for classification generate label predictions with a projection head and softmax activation function. Although successful, these methods do not leverage the relational information between samples when generating label predictions. In recent works, graph-based learning techniques, notably Laplace learning, have been heuristically combined with neural networks for both supervised and semi-supervised learning (SSL) tasks. However, prior works either approximate the gradient of the loss through the graph learning algorithm or decouple the two processes, so end-to-end integration with neural networks has not been achieved. In this work, we derive backpropagation equations, via the adjoint method, for the inclusion of a general family of graph learning layers in a neural network. The resulting method, distinct from graph neural networks, allows us to precisely integrate similarity graph construction and graph Laplacian-based label propagation into a neural network layer, replacing the projection head and softmax activation function for general classification tasks. Our experimental results demonstrate smooth label transitions across the data, improved generalization and robustness to adversarial attacks, and improved training dynamics compared to a standard softmax-based approach.
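The adjoint idea can be illustrated on the linear solve at the core of Laplace learning. Write the forward pass as A U = B with A = L_UU and B = W_UL Y_L. A single extra solve, A^T P = ∂ℓ/∂U, then gives ∂ℓ/∂A = -P U^T and ∂ℓ/∂B = P, and these map back onto the edge weights W. The minimal NumPy sketch below assumes a squared-error loss against a hypothetical target `T` for brevity, and the function names are illustrative. It stops at ∂ℓ/∂W. Gradients with respect to the network's features would follow by chaining through the graph construction W(X), which is not shown here. A central finite-difference check confirms the result.

```python
import numpy as np

def laplace_forward(W, lab, unl, Y_L):
    """Forward pass A U = B with A = L_UU and B = W_UL Y_L."""
    lap = np.diag(W.sum(axis=1)) - W
    A = lap[np.ix_(unl, unl)]
    B = W[np.ix_(unl, lab)] @ Y_L
    return np.linalg.solve(A, B), A

def loss_and_adjoint_grad(W, lab, unl, Y_L, T):
    """Loss 0.5 * ||U - T||_F^2 and its gradient w.r.t. every entry of W via one adjoint solve."""
    U, A = laplace_forward(W, lab, unl, Y_L)
    G = U - T                                  # dloss/dU
    loss = 0.5 * np.sum(G ** 2)
    P = np.linalg.solve(A.T, G)                # adjoint solve: A^T P = dloss/dU
    GA = -P @ U.T                              # dloss/dA
    GB = P                                     # dloss/dB
    GW = np.zeros_like(W)
    GW[np.ix_(unl, unl)] -= GA                 # A contains -W_UU off the degree term
    GW[unl, :] += np.diag(GA)[:, None]         # A's diagonal contains degrees d_i = sum_j W_ij
    GW[np.ix_(unl, lab)] += GB @ Y_L.T         # B = W_UL Y_L
    return loss, GW

# Finite-difference check on a random symmetric graph.
rng = np.random.default_rng(0)
n, k = 6, 2
W = rng.uniform(0.1, 1.0, (n, n))
W = (W + W.T) / 2
np.fill_diagonal(W, 0.0)
lab, unl = np.array([0, 3]), np.array([1, 2, 4, 5])
Y_L = np.eye(k)[[0, 1]]
T = rng.uniform(size=(len(unl), k))

loss, GW = loss_and_adjoint_grad(W, lab, unl, Y_L, T)
eps = 1e-6
fd = np.zeros_like(W)
for i in range(n):
    for j in range(n):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        fd[i, j] = (loss_and_adjoint_grad(Wp, lab, unl, Y_L, T)[0]
                    - loss_and_adjoint_grad(Wm, lab, unl, Y_L, T)[0]) / (2 * eps)
print(np.max(np.abs(fd - GW)))  # should be tiny (roughly finite-difference error)
```

The cost of the backward pass is one additional linear solve with A^T, regardless of how many edge weights there are. This is what makes the adjoint formulation attractive compared with differentiating through an iterative solver step by step.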
