Towards Better Generalization: Weight Decay Induces Low-rank Bias for Neural Networks

Ke Chen, Chugang Yi, Haizhao Yang

TL;DR

This work analyzes why SGD with Weight Decay generalizes well by revealing a WD-induced low-rank bias in two-layer ReLU networks. It proves that, under mild training conditions, the coefficient matrix $V$ becomes close to a rank-$2$ form, with even tighter rank control under small batch gradients. By leveraging this bias, the authors derive improved generalization bounds, replacing the previous $\text{O}\left(\sqrt{\frac{mn \ln m \ln N}{N}}\right)$ rate with a tighter $\text{O}\left(\sqrt{\frac{(m+n) \ln m \ln N}{N}}\right)$ for rank-$2$ networks. Empirical results on California Housing and MNIST corroborate that WD drives $V$ toward low rank and can enhance generalization, supporting a theoretical mechanism for SGD's strong empirical performance.
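
To get a feel for the size of the improvement, the ratio of the two rates can be worked out directly. The values $m = n = 100$ below are hypothetical and chosen only for illustration, and the constants hidden in the $\text{O}(\cdot)$ notation are ignored, so this indicates scaling rather than an exact gain:

$$
\sqrt{\frac{mn \ln m \ln N}{N}} \Bigg/ \sqrt{\frac{(m+n) \ln m \ln N}{N}} = \sqrt{\frac{mn}{m+n}}, \qquad \sqrt{\frac{100 \cdot 100}{100 + 100}} = \sqrt{50} \approx 7.07 .
$$

Under these assumptions the logarithmic factors and $N$ cancel, so the gap between the two bounds depends only on the dimensions $m$ and $n$ of $V$.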

Abstract

We study the implicit bias towards low-rank weight matrices when training neural networks (NN) with Weight Decay (WD). We prove that when a ReLU NN is sufficiently trained with Stochastic Gradient Descent (SGD) and WD, its weight matrix is approximately a rank-two matrix. Empirically, we demonstrate that WD is a necessary condition for inducing this low-rank bias across both regression and classification tasks. Our work differs from previous studies as our theoretical analysis does not rely on common assumptions regarding the training data distribution, optimality of weight matrices, or specific training procedures. Furthermore, by leveraging the low-rank bias, we derive improved generalization error bounds and provide numerical evidence showing that better generalization can be achieved. Thus, our work offers both theoretical and empirical insights into the strong generalization performance of SGD when combined with WD.

Paper Structure (20 sections, 11 theorems, 42 equations, 9 figures, 2 tables)

Key Result

Lemma 2.1

Consider a two-layer NN in Equation eqn:2layerNN, for any fixed $(x,b) \in \mathbb{R}^n \times \mathbb{R}^m$, where $\mathcal{V}^0 \in \mathbb{R}^{m\times n}$ is a measure zero set that depends on $x$ and $b$.

Figures (9)

  • Figure 1: California Housing Prices. Left: Stable rank $r_{\text{s}}(V)$ versus $\mu_V$. Right: Singular values of $V$ for $\mu_V=0.0001$ and $1$. Here we fix batch size $B=16$.
  • Figure 2: California Housing Prices. Left: Stable rank. Middle: Training MSE. Right: Absolute value of generalization error. Here we fix the batch size $B=16$. The sharp transition in the generalization error happens when it changes sign.
  • Figure 3: California Housing Prices. Histograms of the Frobenius norm of all batches in the final epoch.
  • Figure 4: MNIST. Left: Stable rank versus $\mu_V$. Right: Singular values of $V$ for $\mu_V=10^{-5}$ and $1$. Here we fix batch size $B=64$.
  • Figure 5: MNIST. Left: Stable rank. Middle: Training MSE. Right: Absolute value of generalization error. Here we fix the batch size $B=64$. The sharp transition in the generalization error happens when it changes sign.
  • ...and 4 more figures
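
The figure captions report the stable rank $r_{\text{s}}(V)$ without defining it here. A minimal sketch follows, assuming the common definition $r_{\text{s}}(V) = \|V\|_F^2 / \|V\|_2^2$ (squared Frobenius norm over squared spectral norm); the paper's exact definition may differ. The function name `stable_rank`, the matrix sizes, and the random test matrices are all illustrative and not taken from the paper.

```python
import numpy as np


def stable_rank(V: np.ndarray) -> float:
    """Stable rank ||V||_F^2 / ||V||_2^2 (assumed definition, see lead-in)."""
    # Singular values are returned in descending order, so s[0] is the spectral norm.
    s = np.linalg.svd(V, compute_uv=False)
    if s[0] == 0.0:
        return 0.0  # convention chosen here for the all-zero matrix
    return float(np.sum(s ** 2) / s[0] ** 2)


# Hypothetical shapes: m hidden units, n input features.
rng = np.random.default_rng(0)
m, n = 64, 8

# A matrix of exact rank 2 (product of m x 2 and 2 x n factors).
rank_two = rng.standard_normal((m, 2)) @ rng.standard_normal((2, n))

# A dense Gaussian matrix, which has full rank with probability one.
dense = rng.standard_normal((m, n))

print("stable rank, rank-2 matrix:", stable_rank(rank_two))
print("stable rank, dense matrix: ", stable_rank(dense))
```

The stable rank never exceeds the true rank, so the rank-two matrix scores at most 2, while the dense matrix scores higher. Unlike the exact rank, this quantity changes smoothly with small singular values, which makes it convenient for tracking a matrix that is only approximately low rank.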

Theorems & Definitions (21)

  • Lemma 2.1
  • Lemma 2.2
  • Theorem 2.4
  • Theorem 2.6
  • Definition 3.1: Covering Number
  • Definition 3.2: Uniform Covering Number
  • Theorem 3.3: Theorem 17.1 in Anthony and Bartlett (1999)
  • Definition 3.4: Pseudo-dimension
  • Theorem 3.5: Theorem 18.4 in Anthony and Bartlett (1999)
  • Proposition 3.6
  • ...and 11 more