
Gradient-Weight Alignment as a Train-Time Proxy for Generalization in Classification Tasks

Florian A. Hölzl, Daniel Rueckert, Georgios Kaissis

TL;DR

The paper introduces Gradient-Weight Alignment (GWA), a train-time proxy for generalization in supervised classification. By measuring the coherence between per-sample gradients and current weights through an online estimator that leverages only the final-layer gradients, GWA captures training dynamics via the distribution of per-sample alignment scores and their kurtosis-corrected mean. Empirically, GWA-based early stopping matches or surpasses validation-based criteria across CNN and transformer architectures on CIFAR and ImageNet, and yields improvements in robustness to label and input noise. The approach also provides insight into data quality by linking mislabelled or difficult samples to negative alignment and supports effective fine-tuning strategies, with potential applicability to self-supervised and multi-modal settings.
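
A minimal NumPy sketch of the per-sample scores described above, under stated assumptions. The sketch assumes the final layer is a bias-free linear map $W$ applied to penultimate features $\mathbf{h}_i$ and trained with softmax cross-entropy. It also assumes that alignment is the cosine similarity between the negative final-layer gradient (the descent direction) and $W$. The function names `per_sample_alignment` and `alignment_summary` are hypothetical. The summary does not state the kurtosis correction itself, so the sketch only returns the raw mean and excess kurtosis that such a correction would combine.

```python
import numpy as np


def per_sample_alignment(h, W, labels):
    """Cosine similarity between each sample's negative final-layer gradient and W.

    h: (n, d) penultimate features; W: (k, d) bias-free final-layer weights
    (an assumption); labels: (n,) integer class indices.
    """
    logits = h @ W.T                                        # (n, k)
    shifted = logits - logits.max(axis=1, keepdims=True)    # numerically stable softmax
    p = np.exp(shifted)
    p /= p.sum(axis=1, keepdims=True)
    y = np.eye(W.shape[0])[labels]                          # one-hot targets
    # For softmax cross-entropy, grad_W L_i = (p_i - y_i) h_i^T, so the Frobenius
    # inner product <-grad_W L_i, W> reduces to sum_c (y_ic - p_ic) * logit_ic.
    inner = np.sum((y - p) * logits, axis=1)
    # The Frobenius norm of an outer product is the product of the vector norms.
    grad_norm = np.linalg.norm(p - y, axis=1) * np.linalg.norm(h, axis=1)
    return inner / (grad_norm * np.linalg.norm(W) + 1e-12)


def alignment_summary(scores):
    """Mean and excess kurtosis of the per-sample alignment distribution."""
    mean = scores.mean()
    centered = scores - mean
    var = np.mean(centered ** 2)
    excess_kurtosis = np.mean(centered ** 4) / var ** 2 - 3.0
    return mean, excess_kurtosis


# Toy check: two identical feature vectors, the second one labelled as the wrong class.
W = np.eye(2)
h = np.array([[5.0, 0.0], [5.0, 0.0]])
scores = per_sample_alignment(h, W, np.array([0, 1]))
print(scores)                       # approximately [ 0.5 -0.5]
print(alignment_summary(scores))
```

In this toy case the mislabelled sample receives a negative score. This is consistent with the link between mislabelled or difficult samples and negative alignment described above.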

Abstract

Robust validation metrics remain essential in contemporary deep learning, not only to detect overfitting and poor generalization, but also to monitor training dynamics. In the supervised classification setting, we investigate whether interactions between training data and model weights can yield such a metric that both tracks generalization during training and attributes performance to individual training samples. We introduce Gradient-Weight Alignment (GWA), quantifying the coherence between per-sample gradients and model weights. We show that effective learning corresponds to coherent alignment, while misalignment indicates deteriorating generalization. GWA is efficiently computable during training and reflects both sample-specific contributions and dataset-wide learning dynamics. Extensive experiments show that GWA accurately predicts optimal early stopping, enables principled model comparisons, and identifies influential training samples, providing a validation-set-free approach for model analysis directly from the training data.
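
Because GWA is computable during training, its statistics can be accumulated batch by batch rather than in a separate pass. The sketch below is a hypothetical streaming accumulator; `RunningAlignmentMoments` is not from the paper. It keeps the first four power sums of the per-sample scores, so the epoch-level mean and excess kurtosis are available without storing every score. It is not the authors' online estimator, whose update rule is not given here.

```python
import numpy as np


class RunningAlignmentMoments:
    """Streams per-sample alignment scores and exposes mean and excess kurtosis.

    Hypothetical helper: keeps only the count and the power sums of x, x^2, x^3, x^4.
    """

    def __init__(self):
        self.n = 0
        self.power_sums = np.zeros(4)

    def update(self, scores):
        x = np.asarray(scores, dtype=np.float64).ravel()
        self.n += x.size
        self.power_sums += np.array([x.sum(), (x ** 2).sum(), (x ** 3).sum(), (x ** 4).sum()])

    def mean(self):
        return self.power_sums[0] / self.n

    def excess_kurtosis(self):
        m = self.mean()
        e2, e3, e4 = self.power_sums[1:] / self.n          # raw moments E[x^2], E[x^3], E[x^4]
        var = e2 - m ** 2
        # Fourth central moment expanded in terms of raw moments.
        mu4 = e4 - 4 * m * e3 + 6 * m ** 2 * e2 - 3 * m ** 4
        return mu4 / var ** 2 - 3.0


# Usage: feed each minibatch's scores, then read the epoch-level statistics.
acc = RunningAlignmentMoments()
for batch_scores in ([0.5, -0.5], [0.5, -0.5]):
    acc.update(batch_scores)
print(acc.mean(), acc.excess_kurtosis())
```

Raw power sums are simple but can lose precision when the mean is large relative to the spread. Cosine scores lie in $[-1, 1]$, so this is unlikely to matter here. A Welford-style central-moment update is the safer choice in general.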

Paper Structure

This paper contains 30 sections, 3 equations, 17 figures, 5 tables, 1 algorithm.

Figures (17)

  • Figure 1: The alignment among individual per-sample gradients $\nabla\mathcal{L}_i$, and between these gradients and the model weights, varies during training. Coherent per-sample gradient directions reflect generalization. Line plots illustrate how GWA captures gradient coherence and model performance at different time points $t$.
  • Figure 2: GWA tracks validation accuracy and captures subtle training dynamics associated with generalization (left, center) better than LabelWave. Line plots depict normalized values of validation accuracy ($10\%$), LabelWave's prediction change, and GWA's corrected mean across training. Markers indicate the time step for early stopping according to each criterion. The underlying distribution of alignment scores $\gamma(x_i, \mathbf{w}_{T})$ (right) at time $T$ can be read as a cross-section that provides further insight into training. Label noise strongly alters the properties of this distribution on CIFAR-10-N compared with CIFAR-10.
  • Figure 3: Maximum alignment $\mathbb{E}[\mathcal{A}_T]$ allows model performance to be compared across runs. The scatter plot (left) shows the correlation between $\mathbb{E}[\mathcal{A}_T]$ and test accuracy for ConvNeXt and ViT runs of varying performance on CIFAR-10 and its label-noise variants. The correlation is even stronger when evaluating against the robustness benchmark CIFAR-C (center). Pearson and Spearman correlation coefficients for all cases (right) corroborate the visual findings ($p< 0.001$).
  • Figure 3: GWA matches or outperforms other early stopping criteria when fine-tuning a ViT-B/16 pre-trained on ImageNet-21k. Top-1 test accuracy is averaged across 3 seeds, with the min-max range shown below in gray.
  • Figure 4: Per-sample alignment scores $\gamma(x_i, \mathbf{w}_T)$ reveal insights into data characteristics and learning progression. Example images from CIFAR-10 and CIFAR-10-N with highest and lowest alignment scores at epochs 5, 50, and 90 of training. Images displayed for the dog and car classes.
  • ...and 12 more figures
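
The captions describe early stopping at the point selected by each criterion and relate the maximum alignment $\mathbb{E}[\mathcal{A}_T]$ to test performance. One simple way to turn a per-epoch GWA trajectory into a stopping rule is to track its running maximum with a patience window. The function below is a hypothetical sketch of that idea. The `patience` parameter is an assumption, not a setting reported in the paper.

```python
def select_stopping_epoch(gwa_history, patience=10):
    """Return (best_epoch, stop_now) for a non-empty list of per-epoch GWA scores.

    best_epoch is the index of the highest score (earliest one on ties); stop_now is
    True once `patience` epochs have passed without a new maximum.
    """
    if not gwa_history:
        raise ValueError("gwa_history must contain at least one score")
    best_epoch = max(range(len(gwa_history)), key=gwa_history.__getitem__)
    epochs_since_best = len(gwa_history) - 1 - best_epoch
    return best_epoch, epochs_since_best >= patience


# Usage with a short, made-up trajectory (not results from the paper).
history = [0.10, 0.30, 0.20, 0.25]
print(select_stopping_epoch(history, patience=2))   # (1, True)
```

Restoring the checkpoint from `best_epoch` would then play the role usually taken by the best validation checkpoint, without holding out any data.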

Theorems & Definitions (2)

  • Definition 1: Per-Sample Alignment
  • Definition 2: Gradient-Weight Alignment
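
The summary does not reproduce Definitions 1 and 2, so the following rests on several assumptions. It takes per-sample alignment to be measured on the final layer, as the TL;DR states. It further assumes a bias-free linear head $W$ on penultimate features $\mathbf{h}_i$, softmax cross-entropy with one-hot target $\mathbf{y}_i$ for label $y_i$, and cosine similarity against the descent direction. Under these assumptions the score has the closed form below. It needs only the logits, the predicted probabilities, and three norms, which suggests why a final-layer estimator is cheap. It is negative whenever the labelled-class logit $z_{i,y_i}$ falls below the probability-weighted mean logit, in line with the link between mislabelled samples and negative alignment.

```latex
\begin{aligned}
\mathbf{z}_i &= W\mathbf{h}_i, \qquad \mathbf{p}_i = \operatorname{softmax}(\mathbf{z}_i), \qquad
\nabla_W \mathcal{L}_i = (\mathbf{p}_i - \mathbf{y}_i)\,\mathbf{h}_i^\top, \\
\langle -\nabla_W \mathcal{L}_i,\, W \rangle_F &= \sum_c (y_{ic} - p_{ic})\, z_{ic} = z_{i,y_i} - \sum_c p_{ic}\, z_{ic}, \\
\cos\!\left(-\nabla_W \mathcal{L}_i,\, W\right) &= \frac{z_{i,y_i} - \sum_c p_{ic}\, z_{ic}}{\lVert \mathbf{p}_i - \mathbf{y}_i \rVert_2\, \lVert \mathbf{h}_i \rVert_2\, \lVert W \rVert_F}.
\end{aligned}
```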