Stiffness: A New Perspective on Generalization in Neural Networks
Stanislav Fort, Paweł Krzysztof Nowak, Stanislaw Jastrzebski, Srini Narayanan
TL;DR
Stiffness reframes generalization in neural networks as a gradient-alignment phenomenon across data points, quantified by the metrics $S_{\mathrm{sign}}$ and $S_{\cos}$ and linked to a dynamical critical length $\xi$ that characterizes how local a parameter update is in input space. The approach is evaluated on vision tasks (MNIST, Fashion-MNIST, CIFAR-10/100) and on NLP (MNLI with BERT), showing that stiffness tracks within-class transfer, cross-class generalization, and semantic structure such as CIFAR-100 super-classes and super-super-classes, and that it decays with input-space distance. A key finding is that higher learning rates reduce $\xi$, making learned features more local and more easily perturbed by gradient updates, which suggests a regularization effect of the learning rate beyond optimization speed. Overall, stiffness provides a unified diagnostic for generalization, with potential applications in early stopping, architecture search, and meta-learning by promoting locality in representation updates.
Abstract
In this paper we develop a new perspective on generalization of neural networks by proposing and investigating the concept of neural network stiffness. We measure how stiff a network is by looking at how a small gradient step in the network's parameters on one example affects the loss on another example. Higher stiffness suggests that a network is learning features that generalize. In particular, we study how stiffness depends on 1) class membership, 2) distance between data points in the input space, 3) training iteration, and 4) learning rate. We present experiments on MNIST, Fashion-MNIST, and CIFAR-10/100 using fully-connected and convolutional neural networks, as well as on a transformer-based NLP model. We demonstrate the connection between stiffness and generalization, and observe its dependence on learning rate. When training on CIFAR-100, the stiffness matrix exhibits coarse-grained behavior indicative of the model's awareness of super-class membership. In addition, we measure how stiffness between two data points depends on their mutual input-space distance, and establish the concept of a dynamical critical length: a distance below which a parameter update based on a data point influences its neighbors.
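The measurement described in the abstract can be illustrated with a minimal sketch. It assumes that stiffness between two examples is read off from the alignment of their loss gradients: the sign of the dot product for $S_{\mathrm{sign}}$ and the cosine similarity for $S_{\cos}$. The tiny logistic-regression model, the example points, and the helper names `logistic_grad` and `stiffness` are hypothetical choices made for illustration only; they are not the networks or code used in the experiments.

```python
import numpy as np

def logistic_grad(w, x, y):
    """Gradient of the binary cross-entropy loss on one example (x, y) w.r.t. w."""
    p = 1.0 / (1.0 + np.exp(-np.dot(w, x)))  # predicted probability of class 1
    return (p - y) * x

def stiffness(g1, g2, eps=1e-12):
    """Return (sign stiffness, cosine stiffness) for two gradient vectors."""
    dot = float(np.dot(g1, g2))
    s_sign = float(np.sign(dot))
    s_cos = dot / (np.linalg.norm(g1) * np.linalg.norm(g2) + eps)
    return s_sign, s_cos

# Hypothetical toy setup: a linear model at initialization and three nearby points.
w = np.zeros(2)
x_a, y_a = np.array([1.0, 0.5]), 1.0
x_b, y_b = np.array([0.8, 0.7]), 1.0  # same label as a
x_c, y_c = np.array([0.9, 0.6]), 0.0  # opposite label

g_a = logistic_grad(w, x_a, y_a)
g_b = logistic_grad(w, x_b, y_b)
g_c = logistic_grad(w, x_c, y_c)

same = stiffness(g_a, g_b)  # aligned gradients: a step on a also lowers the loss on b
diff = stiffness(g_a, g_c)  # opposed gradients: a step on a raises the loss on c
print("same label:", same)
print("opposite label:", diff)
```

To first order, a gradient step of size $\eta$ on example $a$ changes the loss on example $b$ by $-\eta\, g_a \cdot g_b$, so a positive dot product means the update on $a$ also helps $b$. Averaging these quantities over many pairs, grouped by class or by input-space distance, would give the kind of aggregate statistics the abstract refers to.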
