Stiffness: A New Perspective on Generalization in Neural Networks

Stanislav Fort, Paweł Krzysztof Nowak, Stanislaw Jastrzebski, Srini Narayanan

TL;DR

Stiffness reframes generalization in neural networks as a gradient-alignment phenomenon across datapoints. It is quantified by the metrics $S_{\mathrm{sign}}$ and $S_{\cos}$ and linked to a dynamical critical length $\xi$ that marks the locality of updates. The approach is validated on vision tasks (MNIST, Fashion-MNIST, CIFAR-10/100) and NLP (MNLI with BERT), showing that stiffness tracks within-class transfer, cross-class generalization, and semantic structure such as CIFAR-100 super-classes and super-super-classes, while decaying with input-space distance. A key finding is that higher learning rates reduce $\xi$, making learned features more local and more easily perturbed by gradient updates, which points to a regularization effect of the learning rate beyond optimization speed. Overall, stiffness provides a unified diagnostic for generalization, with potential applications in early stopping, architecture search, and meta-learning, for example by promoting locality in representation updates.
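The two metrics can be sketched as follows. This assumes that $S_{\mathrm{sign}}$ is the sign and $S_{\cos}$ the cosine of the angle between the loss gradients of two examples, averaged over pairs; the summary names the metrics but does not spell out their definitions. `class_stiffness_matrix` is a hypothetical helper that averages the chosen metric over all pairs of distinct examples for each pair of classes. Its diagonal then gives within-class stiffness and its off-diagonal entries give between-class stiffness. The per-example gradients are taken as given, for instance from a framework's automatic differentiation.

```python
import numpy as np


def stiffness_sign(g1: np.ndarray, g2: np.ndarray) -> float:
    """Sign of the gradient alignment: +1 (stiff), -1 (anti-stiff), 0 if orthogonal."""
    return float(np.sign(np.dot(g1, g2)))


def stiffness_cos(g1: np.ndarray, g2: np.ndarray) -> float:
    """Cosine similarity between two flattened loss gradients (0 if either is zero)."""
    denom = np.linalg.norm(g1) * np.linalg.norm(g2)
    return float(np.dot(g1, g2) / denom) if denom > 0 else 0.0


def class_stiffness_matrix(grads, labels, num_classes, metric=stiffness_cos):
    """Hypothetical helper: mean pairwise stiffness between examples of class a and b.

    grads:  (n, p) array, one flattened loss gradient per example (assumed precomputed)
    labels: (n,) integer class labels
    Self-pairs (i == j) are excluded.
    """
    totals = np.zeros((num_classes, num_classes))
    counts = np.zeros((num_classes, num_classes))
    n = len(labels)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            a, b = labels[i], labels[j]
            totals[a, b] += metric(grads[i], grads[j])
            counts[a, b] += 1
    # Class pairs without any example pairs come out as NaN (0 / 0).
    with np.errstate(invalid="ignore"):
        return totals / counts


# Toy usage: two classes whose gradients point in roughly opposite directions.
labels = np.array([0, 0, 1, 1])
grads = np.array([[1.0, 0.1], [0.9, 0.0], [-1.0, 0.2], [-0.8, -0.1]])
S = class_stiffness_matrix(grads, labels, num_classes=2)
within_class = np.mean(np.diag(S))
between_class = S[0, 1]
```

With these toy gradients the diagonal entries are positive (stiff within each class) and the off-diagonal entries are negative (anti-stiff across classes). Passing `metric=stiffness_sign` gives the sign-based variant of the same matrix.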

Abstract

In this paper, we develop a new perspective on generalization of neural networks by proposing and investigating the concept of neural network stiffness. We measure how stiff a network is by looking at how a small gradient step in the network's parameters on one example affects the loss on another example. Higher stiffness suggests that a network is learning features that generalize. In particular, we study how stiffness depends on 1) class membership, 2) distance between data points in the input space, 3) training iteration, and 4) learning rate. We present experiments on MNIST, Fashion-MNIST, and CIFAR-10/100 using fully-connected and convolutional neural networks, as well as on a transformer-based NLP model. We demonstrate the connection between stiffness and generalization, and observe its dependence on learning rate. When training on CIFAR-100, the stiffness matrix exhibits a coarse-grained behavior indicative of the model's awareness of super-class membership. In addition, we measure how stiffness between two data points depends on their mutual input-space distance, and establish the concept of a dynamical critical length: a distance below which a parameter update based on a data point influences its neighbors.
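The measurement described in the abstract can be illustrated on a toy model. The sketch below assumes a logistic-regression model with weights `w` and labels $y \in \{-1, +1\}$. This is an assumption for illustration only, since the paper itself uses fully-connected, convolutional, and transformer networks. The sketch takes one gradient step of size $\varepsilon$ on example 1 and records the change in loss on example 2. To first order this change equals $-\varepsilon\, g_1 \cdot g_2$ with $g_i = \nabla_W L_i$, so a negative change (a stiff pair) corresponds to aligned gradients. The helper names `loss`, `grad`, and `stiffness_delta` are hypothetical.

```python
import numpy as np


def loss(w, x, y):
    """Logistic loss log(1 + exp(-y * w.x)) for a label y in {-1, +1}."""
    return float(np.logaddexp(0.0, -y * (w @ x)))


def grad(w, x, y):
    """Gradient of the logistic loss with respect to the weights w."""
    margin = y * (w @ x)
    return -y * x / (1.0 + np.exp(margin))


def stiffness_delta(w, ex1, ex2, eps=1e-4):
    """Change in loss on ex2 after one gradient-descent step of size eps on ex1.

    Negative values mean the pair is stiff, positive values anti-stiff.
    """
    (x1, y1), (x2, y2) = ex1, ex2
    w_new = w - eps * grad(w, x1, y1)
    return loss(w_new, x2, y2) - loss(w, x2, y2)


# Toy usage: random weights and two random inputs with the same label.
rng = np.random.default_rng(0)
w = rng.normal(size=5)
x1, x2 = rng.normal(size=5), rng.normal(size=5)
ex1, ex2 = (x1, 1.0), (x2, 1.0)

eps = 1e-4
delta = stiffness_delta(w, ex1, ex2, eps)
# First-order prediction -eps * (g1 . g2): the gradient-alignment view of stiffness.
first_order = -eps * grad(w, x1, 1.0) @ grad(w, x2, 1.0)
```

The two numbers differ only by a term of order $\varepsilon^2$, so shrinking `eps` makes the loss-change view and the gradient-alignment view agree more closely.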

Paper Structure

This paper contains 20 sections, 8 equations, and 21 figures.

Figures (21)

  • Figure 1: A diagram illustrating a) the concept of stiffness and b) the dynamical critical length $\xi$. A small gradient update to the network's weights based on example $X_1$ decreases the loss on some examples (stiff) and increases it on others (anti-stiff). We call the characteristic distance over which datapoints are stiff the dynamical critical length $\xi$.
  • Figure 2: A diagram illustrating the concept of stiffness. It can be viewed in two equivalent ways: a) as the change in loss at a datapoint induced by a gradient update based on another datapoint, and b) as the alignment of the loss gradients computed at the two datapoints. For a small update step, these two descriptions are mathematically equivalent (to first order in the step size).
  • Figure 3: The evolution of training and validation loss (top panel), within-class stiffness (central panel), and between-class stiffness (bottom panel) during training. The onset of overfitting is marked in orange. After that point, both within-class and between-class stiffness regress to 0. The same effect is visible in stiffness measured between two training set datapoints, between one training and one validation datapoint, and between two validation set datapoints.
  • Figure 4: Class-membership dependence of stiffness for a CNN on Fashion-MNIST at four different stages of training. The figure shows stiffness between train-train, train-val, and val-val pairs of images, for both the sign and cosine metrics.
  • Figure 5: Stiffness as a function of epoch. The plots summarize the evolution of within-class and between-class stiffness as a function of training epoch for a CNN on Fashion-MNIST.
  • ...and 16 more figures
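Figure 1 describes $\xi$ as the characteristic distance over which datapoints are stiff. One simple way to estimate it works as follows, though this is an assumption rather than the paper's documented procedure. Bin datapoint pairs by input-space distance, average their stiffness in each bin, and interpolate the distance at which the binned mean first falls to zero. `critical_length` is a hypothetical helper, and the toy data below is constructed with a known zero crossing at distance 2.

```python
import numpy as np


def critical_length(distances, stiffness, num_bins=20):
    """Hypothetical estimate of xi: first zero crossing of binned mean stiffness.

    distances: (m,) input-space distances for m pairs of datapoints
    stiffness: (m,) stiffness values (e.g. S_cos) for the same pairs
    Returns None if the binned mean never goes from positive to non-positive.
    """
    distances = np.asarray(distances, dtype=float)
    stiffness = np.asarray(stiffness, dtype=float)
    edges = np.linspace(distances.min(), distances.max(), num_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    # Assign each pair to a bin; clip so the maximum distance lands in the last bin.
    idx = np.clip(np.digitize(distances, edges) - 1, 0, num_bins - 1)
    means = np.array([stiffness[idx == b].mean() if np.any(idx == b) else np.nan
                      for b in range(num_bins)])
    # Drop empty bins before searching for the sign change.
    valid = ~np.isnan(means)
    c, m = centers[valid], means[valid]
    for k in range(1, len(m)):
        if m[k - 1] > 0 >= m[k]:
            # Linear interpolation between the two bin centers that straddle zero.
            t = m[k - 1] / (m[k - 1] - m[k])
            return float(c[k - 1] + t * (c[k] - c[k - 1]))
    return None


# Toy data: stiffness decays linearly with distance and crosses zero at d = 2.
d = np.linspace(0.0, 4.0, 401)
s = 1.0 - d / 2.0
xi = critical_length(d, s)
```

Binning before searching for the crossing keeps a few noisy pairs from producing a spurious zero crossing. The interpolation step then recovers a distance finer than the bin width.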