
Explicit regularization and implicit bias in deep network classifiers trained with the square loss

Tomaso Poggio, Qianli Liao

TL;DR

This work analyzes why deep ReLU networks trained with the square loss generalize well by examining gradient flow under normalization (Batch Normalization or Weight Normalization) and explicit regularization (Weight Decay). It shows that, under average separability, gradient flow converges to interpolating solutions with the smallest norm, whose equilibrium value is $\rho_{eq}=\dfrac{\sum_n y_n f_n}{\lambda+\sum_n f_n^2}$, where $f_n$ is the output of the normalized network on the $n$-th training example, $y_n$ its label and $\lambda$ the weight-decay coefficient. These solutions have the largest margin and the best bounds on the expected classification error. BN and WD make the dynamics well-posed, while without WD and BN the regularization is implicit and depends on the initialization. The theory connects to Neural Collapse through constraints on the layer weights at convergence and predicts specific relations among the layer matrices that must hold at critical points. Overall, normalization and explicit regularization drive deep networks toward high-margin, low-norm interpolants with favorable generalization properties, and the analysis offers testable predictions about training dynamics.
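The expression for $\rho_{eq}$ is the stationary point of a square loss written in terms of the scale $\rho$. As a minimal sketch, assuming the regularized objective $L(\rho)=\sum_n(\rho f_n-y_n)^2+\lambda\rho^2$ with the normalized outputs $f_n$ held fixed (an assumption made here for illustration, consistent with the formula above), setting $dL/d\rho=0$ gives $\rho\,(\lambda+\sum_n f_n^2)=\sum_n y_n f_n$. The Python below checks this numerically; the outputs `f` are hypothetical values, not data from the paper.

```python
# Sketch: rho_eq as the stationary point of an ASSUMED loss in the scale rho.
# Assumption: L(rho) = sum_n (rho * f_n - y_n)^2 + lam * rho^2, with f_n held fixed.


def rho_eq(y, f, lam):
    """Closed form: rho_eq = sum_n y_n f_n / (lam + sum_n f_n^2)."""
    num = sum(yn * fn for yn, fn in zip(y, f))
    den = lam + sum(fn * fn for fn in f)
    return num / den


def loss(rho, y, f, lam):
    """Assumed regularized square loss as a function of the scale rho."""
    return sum((rho * fn - yn) ** 2 for yn, fn in zip(y, f)) + lam * rho ** 2


def dloss_drho(rho, y, f, lam):
    """Analytic derivative: 2 * sum_n (rho f_n - y_n) f_n + 2 * lam * rho."""
    return 2 * sum((rho * fn - yn) * fn for yn, fn in zip(y, f)) + 2 * lam * rho


if __name__ == "__main__":
    y = [1, -1, 1, -1]                # labels in {-1, +1}
    f = [0.2, -0.25, 0.3, -0.15]      # hypothetical normalized outputs f_n
    lam = 0.01                        # weight-decay coefficient lambda
    r = rho_eq(y, f, lam)
    print("rho_eq =", r)
    print("dL/drho at rho_eq =", dloss_drho(r, y, f, lam))
    # L is a convex quadratic in rho, so the stationary point is the minimum.
    print(loss(r, y, f, lam) < loss(r * 1.1, y, f, lam))
```

With these values $\sum_n y_n f_n = 0.9$ and $\lambda+\sum_n f_n^2 = 0.225$, so $\rho_{eq}$ is approximately 4 and the derivative vanishes there up to rounding. In this toy setting, a larger $\lambda$ yields a smaller $\rho_{eq}$.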

Abstract

Deep ReLU networks trained with the square loss have been observed to perform well in classification tasks. We provide here a theoretical justification based on an analysis of the associated gradient flow. We show that convergence to a solution with the absolute minimum norm is expected when normalization techniques such as Batch Normalization (BN) or Weight Normalization (WN) are used together with Weight Decay (WD). The main property of the minimizers that bounds their expected error is the norm: we prove that among all the close-to-interpolating solutions, the ones associated with smaller Frobenius norms of the unnormalized weight matrices have a larger margin and better bounds on the expected classification error. With BN but in the absence of WD, the dynamical system is singular. Implicit dynamical regularization (that is, zero initial conditions biasing the dynamics towards high-margin solutions) is also possible in the no-BN and no-WD case. The theory yields several predictions, including the role of BN and weight decay, aspects of Papyan, Han and Donoho's Neural Collapse, and the constraints induced by BN on the network weights.
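The abstract separates the Frobenius norms of the unnormalized weight matrices from the margin. For a bias-free ReLU network, the standard way to make this split uses positive homogeneity, $\mathrm{ReLU}(cz)=c\,\mathrm{ReLU}(z)$ for $c>0$: writing $W_k=\|W_k\|_F V_k$ gives $f_W(x)=\rho\, f_V(x)$ with $\rho=\prod_k\|W_k\|_F$. The sketch below assumes such a network (no biases, no normalization layers); the layer sizes, random weights and helper names are hypothetical and only illustrate the identity and the margin $\min_n y_n f_V(x_n)$ of the normalized network.

```python
import math
import random


def relu(v):
    return [max(0.0, x) for x in v]


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def frob(W):
    """Frobenius norm of W."""
    return math.sqrt(sum(w * w for row in W for w in row))


def forward(weights, x):
    """Bias-free ReLU net: ReLU after every layer but the last (scalar output)."""
    h = x
    for W in weights[:-1]:
        h = relu(matvec(W, h))
    return matvec(weights[-1], h)[0]


def split_norms(weights):
    """Return (rho, V) with W_k = ||W_k||_F * V_k and rho = prod_k ||W_k||_F."""
    rho, vs = 1.0, []
    for W in weights:
        n = frob(W)
        rho *= n
        vs.append([[w / n for w in row] for row in W])
    return rho, vs


def normalized_margin(weights, xs, ys):
    """Margin of the normalized network: min_n y_n * f_V(x_n)."""
    _, vs = split_norms(weights)
    return min(y * forward(vs, x) for x, y in zip(xs, ys))


if __name__ == "__main__":
    rng = random.Random(0)
    shapes = [(8, 4), (8, 8), (1, 8)]          # hypothetical layer sizes (rows, cols)
    weights = [[[rng.gauss(0.0, 1.0) for _ in range(c)] for _ in range(r)]
               for r, c in shapes]
    rho, vs = split_norms(weights)
    x = [rng.gauss(0.0, 1.0) for _ in range(4)]
    print(forward(weights, x), rho * forward(vs, x))   # equal up to rounding
```

Because rescaling any $W_k$ by a positive constant changes only $\rho$, the normalized margin is unaffected, which is why norm and margin can be treated as separate quantities.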

Paper Structure

This paper contains 8 sections, 1 theorem, 8 equations, 11 figures.

Key Result

Lemma 1

If the gradient flow with normalization and weight decay converges to an interpolating solution with near-zero square loss, the following properties hold:

Figures (11)

  • Figure 1: ConvNet with Batch Normalization and Weight Decay. Binary classification on two classes from CIFAR-10, trained with MSE loss. The model is a very simple network with 4 convolutional layers and ReLU nonlinearities. Batch normalization is used. The weight matrices of all layers are initialized with a zero-mean normal distribution, scaled by a constant such that the Frobenius norm of each matrix is 5. We use weight decay of 0.01. We run SGD with batch size 128, constant learning rate 0.1 and momentum 0.9 for 1000 epochs. No data augmentation. Every input to the network is scaled such that it has Frobenius norm 1.
  • Figure 2: ConvNet with Batch Normalization and Weight Decay. Dynamics of $\rho$ from the experiments in Figure 1. First row: small initialization (0.1). Second row: medium initialization (1). Third row: large initialization (5). A dashed rectangle denotes the previous subplot's domain and range in the new subplot.
  • Figure 3: ConvNet with Batch Normalization and Weight Decay. Dynamics of the average of $|f_n|$ from the experiments in Figure 1. First row: small initialization (0.1). Second row: medium initialization (1). Third row: large initialization (5). A dashed rectangle denotes the previous subplot's domain and range in the new subplot.
  • Figure 4: ConvNet with Batch Normalization and Weight Decay. Margin of all training samples.
  • Figure 5: ConvNet with Batch Normalization but no Weight Decay. Binary classification on two classes from CIFAR-10, trained with MSE loss. The model is a very simple network with 4 convolutional layers and ReLU nonlinearities. Batch normalization is used without learnable affine parameters (affine=False in PyTorch). The weight matrices of all layers are initialized with a zero-mean normal distribution, scaled by a constant such that the Frobenius norm of each matrix is either 0.1 or 5. We run SGD with batch size 128, constant learning rate 0.01 and momentum 0.9 for 1000 epochs. No data augmentation. Every input to the network is scaled such that it has Frobenius norm 1. This is a single run, but it is typical of the parameter values we used.
  • ...and 6 more figures
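Several ingredients recur in these captions: weight matrices rescaled to a fixed Frobenius norm at initialization, inputs scaled to Frobenius norm 1, SGD with momentum and weight decay, and batch normalization without learnable parameters. The pure-Python sketch below illustrates each on plain lists; it is not the authors' training code. The helper names are hypothetical, the momentum/weight-decay update is assumed to follow the convention documented for PyTorch's `torch.optim.SGD` (no dampening, no Nesterov), and the normalization mimics training-mode batch normalization with `affine=False` (batch mean, biased batch variance, `eps=1e-5`).

```python
import math


def scale_to_frobenius(W, target):
    """Rescale W (list of rows) so that its Frobenius norm equals target."""
    n = math.sqrt(sum(w * w for row in W for w in row))
    return [[w * target / n for w in row] for row in W]


def unit_normalize(x):
    """Scale an input vector so that its (Frobenius) norm is 1."""
    n = math.sqrt(sum(v * v for v in x))
    return [v / n for v in x]


def batch_norm_no_affine(batch, eps=1e-5):
    """Training-mode batch norm with no learnable scale or shift (assumed to mirror affine=False).

    batch: list of samples, each a list of features. Each feature is
    centered by its batch mean and divided by sqrt(biased variance + eps).
    """
    m, d = len(batch), len(batch[0])
    means = [sum(s[j] for s in batch) / m for j in range(d)]
    variances = [sum((s[j] - means[j]) ** 2 for s in batch) / m for j in range(d)]
    return [[(s[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(d)]
            for s in batch]


def sgd_step(params, grads, bufs, lr=0.1, momentum=0.9, weight_decay=0.01):
    """One SGD step with L2 weight decay and momentum on flat parameter lists.

    d = g + weight_decay * p;  buf = momentum * buf + d;  p = p - lr * buf
    (bufs start at zero, so the first step uses buf = d).
    """
    new_params, new_bufs = [], []
    for p, g, b in zip(params, grads, bufs):
        d = g + weight_decay * p
        b = momentum * b + d
        new_params.append(p - lr * b)
        new_bufs.append(b)
    return new_params, new_bufs


if __name__ == "__main__":
    W = scale_to_frobenius([[1.0, 2.0], [-1.0, 0.5]], 5.0)   # init with norm 5
    xs = [unit_normalize(x) for x in ([3.0, 4.0], [1.0, -1.0], [0.0, 2.0])]
    pre = [[sum(w * xi for w, xi in zip(row, x)) for row in W] for x in xs]
    pre3 = [[3.0 * v for v in s] for s in pre]                # weights scaled by 3
    print(batch_norm_no_affine(pre))
    print(batch_norm_no_affine(pre3))                          # nearly identical
    print(sgd_step([1.0], [0.5], [0.0]))                       # p = 1 - 0.1 * 0.51
```

In the example, scaling the weights by 3 leaves the batch-normalized features essentially unchanged (only `eps` breaks the exact invariance). This scale invariance is the property that leaves the norm of weights feeding into BN undetermined by the data-fit term alone when no weight decay is used.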

Theorems & Definitions (1)

  • Lemma 1