
An accurate flatness measure to estimate the generalization performance of CNN models

Rahman Taleghani, Maryam Mohammadi, Francesco Marchetti

TL;DR

This work develops a flatness measure that is both exact and architecturally faithful for a broad, practically relevant class of CNNs, and empirically evaluates it on families of CNNs trained on standard image-classification benchmarks.

Abstract

Flatness measures based on the spectrum or the trace of the Hessian of the loss are widely used as proxies for the generalization ability of deep networks. However, most existing definitions either are tailored to fully connected architectures and rely on stochastic estimators of the Hessian trace, or ignore the specific geometric structure of modern Convolutional Neural Networks (CNNs). In this work, we develop a flatness measure that is both exact and architecturally faithful for a broad and practically relevant class of CNNs. We first derive a closed-form expression for the trace of the Hessian of the cross-entropy loss with respect to the convolutional kernels in networks that use global average pooling followed by a linear classifier. Building on this result, we then specialize the notion of relative flatness to convolutional layers and obtain a parameterization-aware flatness measure that accounts for the scaling symmetries and filter interactions induced by convolution and pooling. Finally, we empirically investigate the proposed measure on families of CNNs trained on standard image-classification benchmarks. The results suggest that the proposed measure can serve as a robust tool for assessing and comparing the generalization performance of CNN models and for guiding architecture and training choices in practice.
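The scaling-symmetry argument in the abstract can be illustrated numerically. The following is a minimal NumPy sketch under assumptions of ours, not taken from the paper: one input image represented by its im2col patch matrix (patches $x_p$, $p = 1, \dots, P$), a bias-free convolution followed by ReLU, global average pooling (GAP), and a bias-free linear head `W`. Because ReLU is piecewise linear, the Hessian with respect to the filters reduces almost everywhere to its Gauss-Newton term, which gives $\operatorname{Tr}(H_{j,j'}) = (w_j^\top H_z w_{j'})\,\langle u_j, u_{j'}\rangle$, where $w_j$ is column $j$ of `W`, $H_z = \operatorname{diag}(p) - pp^\top$ for the softmax output $p$, and $u_j = \frac{1}{P}\sum_p \mathbf{1}[k_j^\top x_p > 0]\,x_p$. The relative-flatness quantity computed here, $\kappa = \sum_{j,j'} \langle k_j, k_{j'}\rangle \operatorname{Tr}(H_{j,j'})$, follows the trace-based form of Petzka et al. (2021) with filters in the role of neurons; the paper's Definition 1 may differ, and all function names are hypothetical.

```python
import numpy as np

def gap_head_terms(K, W, patches):
    """Per-image terms for a bias-free conv layer + ReLU + GAP + linear head.

    K: (C_out, d) vectorized filters k_j; W: (num_classes, C_out) head weights;
    patches: (P, d) im2col patches of one image (hypothetical layout).
    """
    pre = patches @ K.T                           # (P, C_out) convolution outputs
    feats = np.maximum(pre, 0.0).mean(axis=0)     # ReLU followed by global average pooling
    logits = W @ feats
    p = np.exp(logits - logits.max())
    p /= p.sum()                                  # softmax probabilities
    Hz = np.diag(p) - np.outer(p, p)              # Hessian of cross-entropy w.r.t. logits
    A = W.T @ Hz @ W                              # A[j, j'] = w_j^T Hz w_j'
    # u_j = (1/P) * sum of the patches where filter j is active (ReLU derivative is 1 there)
    U = (pre > 0).astype(float).T @ patches / len(patches)   # (C_out, d)
    return A, U

def hessian_trace(K, W, patches):
    """Trace of the Gauss-Newton Hessian w.r.t. all filters: sum_j A[j, j] * ||u_j||^2."""
    A, U = gap_head_terms(K, W, patches)
    return float(np.sum(np.diag(A) * np.sum(U * U, axis=1)))

def relative_flatness(K, W, patches):
    """sum_{j,j'} <k_j, k_j'> * Tr(H_{j,j'}), with Tr(H_{j,j'}) = A[j, j'] * <u_j, u_j'>."""
    A, U = gap_head_terms(K, W, patches)
    return float(np.sum((K @ K.T) * A * (U @ U.T)))

# Toy check: rescale each filter by alpha_j > 0 and the matching head column by 1 / alpha_j.
rng = np.random.default_rng(0)
K = rng.normal(size=(4, 12))
W = rng.normal(size=(3, 4))
patches = rng.normal(size=(25, 12))
alpha = np.array([2.0, 0.5, 3.0, 1.0])
K_scaled, W_scaled = K * alpha[:, None], W / alpha[None, :]
print(relative_flatness(K, W, patches), relative_flatness(K_scaled, W_scaled, patches))  # equal up to rounding
print(hessian_trace(K, W, patches), hessian_trace(K_scaled, W_scaled, patches))          # generally different
```

The rescaling leaves the network function unchanged because ReLU is positively homogeneous. In the sketch, $\kappa$ is unchanged while the raw Hessian trace is not, which is the kind of parameterization dependence a relative measure is meant to remove. For a dataset, both quantities would be averaged over per-sample Hessians.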

Paper Structure (35 sections, 4 theorems, 74 equations, 14 figures, 5 tables)

Key Result

Theorem 1

Let $x \in \mathbb{R}^{C_{\text{in}} \times H \times W}$ be an input image, and let $\{k_j\}_{j=1}^{C_{\text{out}}} \subset \mathbb{R}^d$ be the vectorized convolutional filters, where $d = C_{\text{in}} K_H K_W$. Suppose the output logits are computed as: Then, the trace of the Hessian of the cross-entropy loss with respect to all convolutional weights is:
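A hedged numerical sketch of the kind of closed form Theorem 1 refers to. Assumptions (ours, for illustration): the input $x$ is represented by its $P$ im2col patches $x_p$ of length $d = C_{\text{in}} K_H K_W$, the convolution has no bias and is followed by ReLU, then global average pooling $g$, then logits $z = W g + b$. Under these assumptions the cross-entropy is, almost everywhere, a smooth function of logits that are locally linear in the filters, so the trace reduces to $\sum_j (w_j^\top H_z w_j)\,\lVert u_j \rVert^2$, where $w_j$ is column $j$ of $W$, $H_z = \operatorname{diag}(p) - pp^\top$ for the softmax output $p$, and $u_j = \frac{1}{P}\sum_p \mathbf{1}[k_j^\top x_p > 0]\,x_p$. This is our derivation under those assumptions, not necessarily the theorem's exact expression; the sketch checks it against a finite-difference trace. Function names are hypothetical.

```python
import numpy as np

def cross_entropy(K, W, b, patches, y):
    """Loss of one image: bias-free conv + ReLU + GAP + linear head z = W g + b."""
    feats = np.maximum(patches @ K.T, 0.0).mean(axis=0)   # g_j, shape (C_out,)
    z = W @ feats + b
    z = z - z.max()                                       # numerically stable log-sum-exp
    return float(np.log(np.exp(z).sum()) - z[y])

def closed_form_trace(K, W, b, patches):
    """sum_j (w_j^T Hz w_j) * ||u_j||^2, valid wherever no pre-activation is exactly 0."""
    pre = patches @ K.T
    z = W @ np.maximum(pre, 0.0).mean(axis=0) + b
    p = np.exp(z - z.max())
    p /= p.sum()
    Hz = np.diag(p) - np.outer(p, p)                      # does not depend on the label y
    U = (pre > 0).astype(float).T @ patches / len(patches)   # u_j, shape (C_out, d)
    quad = np.einsum("aj,ab,bj->j", W, Hz, W)             # w_j^T Hz w_j for every filter j
    return float(np.sum(quad * np.sum(U * U, axis=1)))

def finite_difference_trace(K, W, b, patches, y, eps=1e-3):
    """Sum of central second differences over every filter coordinate.

    Only meaningful if no ReLU pre-activation changes sign within the +/- eps steps.
    """
    base = cross_entropy(K, W, b, patches, y)
    total = 0.0
    for idx in np.ndindex(*K.shape):
        step = np.zeros_like(K)
        step[idx] = eps
        plus = cross_entropy(K + step, W, b, patches, y)
        minus = cross_entropy(K - step, W, b, patches, y)
        total += (plus - 2.0 * base + minus) / eps**2
    return total

rng = np.random.default_rng(0)
K, W, b = rng.normal(size=(3, 8)), rng.normal(size=(5, 3)), rng.normal(size=5)
patches = rng.normal(size=(16, 8))
print(closed_form_trace(K, W, b, patches), finite_difference_trace(K, W, b, patches, y=1))
```

The two printed values should agree to several digits whenever no pre-activation lies within a step of zero. Note that the closed form needs no label: the Hessian of cross-entropy with respect to the logits depends only on $p$.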

Figures (14)

  • Figure 1: Framework of the last layers of the model.
  • Figure 4: Impact of learning rate and optimizer choice on flatness and generalization.
  • Figure 5: Comparison of training dynamics, relative flatness, and validation accuracy for the four best parameter sets of the evaluated model.
  • Figure 6: Effect of label noise on flatness and its predictive power. Left: Flatness increases with label noise level. Right: Spearman correlation between flatness and generalization gap, indicating flatness becomes more predictive under harder generalization scenarios.
  • Figure 7: Flatness dynamics during transfer learning. Top: Evolution of symbolic flatness (log scale) across strategies. Low learning rates maintain pre-trained flatness, while layer freezing leads to significant sharpness spikes. Bottom: Correlation between final accuracy and flatness (Spearman $\rho = -0.68, p < 10^{-6}$). Each diamond represents the final state of a specific strategy.
  • ...and 9 more figures
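Figures 6 and 7 summarize the link between flatness and generalization with Spearman rank correlations. A minimal sketch of that statistic, assuming no tied values (ties require average ranks, which `scipy.stats.spearmanr` handles together with p-values). The arrays below are toy placeholders, not results from the paper, and `spearman_rho` is a hypothetical helper name.

```python
import numpy as np

def spearman_rho(a, b):
    """Spearman's rho as the Pearson correlation of ranks (assumes no ties)."""
    rank_a = np.argsort(np.argsort(a)).astype(float)
    rank_b = np.argsort(np.argsort(b)).astype(float)
    return float(np.corrcoef(rank_a, rank_b)[0, 1])

# Toy placeholder values (not data from the paper): one entry per trained model.
flatness = np.array([0.8, 2.5, 1.1, 4.0, 3.2])
gen_gap = np.array([0.02, 0.06, 0.03, 0.12, 0.09])
print(spearman_rho(flatness, gen_gap))
```

Identical rankings, as in this toy input, give $\rho = 1$; fully reversed rankings give $\rho = -1$. Because only ranks matter, the statistic is unaffected by monotone rescalings such as plotting flatness on a log scale.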

Theorems & Definitions (12)

  • Theorem 1: Trace of the Hessian Under GAP
  • proof
  • Corollary 2: Lipschitz Continuity of Flatness
  • proof
  • Remark 1: Monotonicity Under Gradient Descent
  • proof
  • Corollary 3
  • Definition 1: Convolutional Flatness under GAP
  • Theorem 4: Generalization Bound via Relative Flatness (Petzka et al., 2021)
  • Example 1
  • ...and 2 more