Decoupled Weight Decay for Any $p$ Norm

Nadav Joseph Outmezguine, Noam Levi

TL;DR

This work addresses the resource-intensive training of large neural networks by introducing a decoupled weight decay scheme for $L_p$ regularization that yields highly sparse models while preserving generalization. It derives a bi-convex reformulation with auxiliary variables $s$ and presents the proximal-gradient-based $p$-norm Weight Decay ($p$WD) update $w \leftarrow (w - \alpha \nabla \mathcal{L})/(1 + \alpha \lambda_p s)$ with $s = |w|^{p-2}$, which integrates with adaptive optimizers. Empirically, $p$WD reaches very high sparsity (up to ~99.5%) on CIFAR-10 and Tiny Shakespeare with accuracy rivaling AdamW. Sparsity is strongest for $p<1$, while generalization peaks around $1<p<2$. The study also discusses limitations and extensions such as $s$-dynamics, $p$-scheduling, and elastic-net hybrids. Overall, $p$WD provides a practical, low-overhead path to sparse training, with performance comparable to state-of-the-art pruning methods, and enables energy- and memory-efficient deployment, with potential use in broader optimization contexts.
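As a rough illustration of the update quoted above, here is a minimal elementwise sketch in plain Python. The function name `pwd_step` and its argument names are hypothetical, and the rewrite in terms of `r = 1/s = |w|**(2-p)` is an assumption made here for numerical convenience (it is algebraically the same update for nonzero weights and avoids dividing by zero at `w == 0`). It covers only `0 < p <= 2` with `lr * lam > 0`; with an adaptive optimizer, `grad` would presumably be replaced by that optimizer's update direction.

```python
def pwd_step(w, grad, lr, lam, p):
    """One elementwise pWD-style step on lists of floats (sketch, 0 < p <= 2).

    Computes w <- (w - lr*grad) / (1 + lr*lam*s) with s = |w|**(p-2),
    rewritten with r = 1/s = |w|**(2-p) so that w == 0 needs no special case.
    Assumes lr*lam > 0, so the denominator below is never zero.
    """
    out = []
    for wi, gi in zip(w, grad):
        r = abs(wi) ** (2.0 - p)  # r = 1/s; equals 0.0 at wi == 0 when p < 2
        out.append((wi - lr * gi) * r / (r + lr * lam))
    return out


if __name__ == "__main__":
    w = [1.0, -0.5, 0.0]
    g = [0.2, -0.1, 0.3]
    # p = 2: every weight is shrunk by the same factor 1 / (1 + lr*lam).
    print(pwd_step(w, g, lr=0.1, lam=1.0, p=2.0))
    # p < 2: small weights are shrunk harder, and a zero weight stays at zero.
    print(pwd_step(w, g, lr=0.1, lam=1.0, p=0.5))
```

For `p = 2` the shrink factor is independent of the weight, while for `p < 2` it tends to zero as the weight does, which is the mechanism behind the sparsity described in this summary.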

Abstract

With the success of deep neural networks (NNs) in a variety of domains, the computational and storage requirements for training and deploying large NNs have become a bottleneck for further improvements. Sparsification has consequently emerged as a leading approach to tackle these issues. In this work, we consider a simple yet effective approach to sparsification, based on the Bridge, or $L_p$ regularization during training. We introduce a novel weight decay scheme, which generalizes the standard $L_2$ weight decay to any $p$ norm. We show that this scheme is compatible with adaptive optimizers, and avoids the gradient divergence associated with $0<p<1$ norms. We empirically demonstrate that it leads to highly sparse networks, while maintaining generalization performance comparable to standard $L_2$ regularization.
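
The gradient divergence mentioned in the abstract, and how a decay-style update can avoid it, can be sketched for a single weight. The normalization $\lambda_p |w|^p / p$ used below is an assumption chosen to match the TL;DR update and the toy loss of Figure 1; the paper's own derivation goes through the bi-convex reformulation and may differ in detail.

$$
\frac{\partial}{\partial w}\,\frac{\lambda_p}{p}\,|w|^p \;=\; \lambda_p\,\operatorname{sign}(w)\,|w|^{p-1} \;=\; \lambda_p\,|w|^{p-2}\,w ,
$$

whose magnitude grows without bound as $w \to 0$ when $0<p<1$. Holding $s = |w|^{p-2}$ fixed and treating the decay term implicitly gives

$$
w^{+} \;=\; w - \alpha \nabla \mathcal{L} - \alpha \lambda_p\, s\, w^{+}
\quad\Longrightarrow\quad
w^{+} \;=\; \frac{w - \alpha \nabla \mathcal{L}}{1 + \alpha \lambda_p\, s} .
$$

For $p<2$, $s \to \infty$ as $w \to 0$, so the factor $1/(1+\alpha\lambda_p s)$ tends to 0 and the step remains finite: the weight is pulled toward zero rather than thrown away from it.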

Paper Structure (22 sections, 3 theorems, 30 equations, 5 figures, 2 tables, 1 algorithm)

This paper contains 22 sections, 3 theorems, 30 equations, 5 figures, 2 tables, 1 algorithm.

Key Result

Theorem 1.1

Let ${\cal L}:\;\mathbb{R}^n\to\mathbb{R}$, let $R:\;\mathbb{R}^n\times\mathbb{R}^n\to\mathbb{R}$ be a smooth bi-convex function. Define $F(\boldsymbol{w},\boldsymbol{s})={\cal L}(\boldsymbol{w})+R(\boldsymbol{w},\boldsymbol{s})$, and $\hat{s}(\cdot)=\underset{\boldsymbol{s}}{\operatorname{argmin}}\,\ldots$

Figures (5)

  • Figure 1: Toy example of weight evolution under gradient descent for the loss ${\cal L}=(w-1)^2/2+\left\Vert w \right\Vert_p^p/p$. The dotted line represents plain gradient descent, where the gradient of the norm is added directly to the loss gradient; the weight fails to converge to 0 because the gradient of the $p$-norm explodes near 0. The dashed line represents the evolution of the weight under the update rule in Eq. \ref{eq:proximal}, with $s$ updated every 20 $w$ steps. The solid line represents the same update rule with $s$ updated at every $w$ step; this is an implementation of $p$-norm Weight Decay ($p$WD). In both implementations of our method, the weight converges smoothly to 0.
  • Figure 2: Validation accuracy vs. sparsity for ResNet18 trained on CIFAR-10. Each point represents a different instance of the network trained for 100 epochs, with a different choice of $p$, $\lambda_p$, and learning rate $\alpha$. Points of different colors indicate different choices of $p$, optimizing over $\lambda_p, \alpha$. The dashed red line indicates the best accuracy achieved using AdamW. The orange stars indicate the best-accuracy runs obtained using Only Train Once (OTO). The green crosses indicate the best accuracies obtained using iterative magnitude pruning. Left: Validation accuracy vs. sparsity. Right: Example of the accuracy/sparsity trade-off given in Eq. \ref{eq:tradeoff}.
  • Figure 3: Validation accuracy vs. sparsity for nanoGPT trained on Tiny Shakespeare. Each point represents a different instance of the network trained for 5000 iterations, with a different choice of $p$, $\lambda_p$, and learning rate $\alpha$. Points of different colors indicate different choices of $p$, optimizing over $\lambda_p, \alpha$. The dashed red line indicates the best accuracy achieved using AdamW. The green crosses indicate the best accuracies obtained using iterative magnitude pruning. Left: Validation accuracy vs. sparsity. Right: Example of the accuracy/sparsity trade-off given in Eq. \ref{eq:tradeoff}.
  • Figure 4: Contours of validation accuracy after 100 training epochs on the $\lambda_p$ vs. learning rate plane, for ResNet18 on CIFAR-10. White contours mark the [0.01, 0.2, 0.4, 0.8] sparsity levels.
  • Figure 5: Contours of validation accuracy after 5000 training iterations on the $\lambda_p$ vs. learning rate plane, for nanoGPT on Tiny Shakespeare. White contours mark the [0.01, 0.2, 0.4, 0.8] sparsity levels.
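
The toy experiment described in the Figure 1 caption can be approximated with a few lines of plain Python. This is a sketch under assumptions that the caption does not state: $p = 0.5$, learning rate 0.01, initial weight 1.0, and 1000 steps are illustrative choices, and the function names are hypothetical. The plain-gradient variant uses 0 as the regularizer gradient at exactly $w = 0$, and the decay variant tracks `r = 1/s = |w|**(2-p)` instead of `s` to avoid floating-point overflow; both are conventions adopted here, not taken from the paper.

```python
def sign(x):
    return (x > 0) - (x < 0)


def plain_gd(w, p, lr, steps):
    """Gradient descent on (w-1)^2/2 + |w|^p/p with the norm's gradient added directly."""
    traj = [w]
    for _ in range(steps):
        # d/dw |w|^p / p = sign(w) * |w|^(p-1), which blows up near 0 for p < 1.
        # At exactly w == 0 we use 0.0 as a convention (assumption).
        reg = sign(w) * abs(w) ** (p - 1.0) if w != 0.0 else 0.0
        w = w - lr * ((w - 1.0) + reg)
        traj.append(w)
    return traj


def pwd_toy(w, p, lr, steps, s_every=1):
    """Decay-style update w <- (w - lr*(w-1)) / (1 + lr*s), s refreshed every s_every steps.

    Stored as r = 1/s = |w|^(2-p), so the update reads (w - lr*(w-1)) * r / (r + lr).
    """
    traj = [w]
    for t in range(steps):
        if t % s_every == 0:  # always true at t == 0, so r is defined before use
            r = abs(w) ** (2.0 - p)
        w = (w - lr * (w - 1.0)) * r / (r + lr)
        traj.append(w)
    return traj


if __name__ == "__main__":
    gd = plain_gd(1.0, p=0.5, lr=0.01, steps=1000)
    every_step = pwd_toy(1.0, p=0.5, lr=0.01, steps=1000, s_every=1)
    every_20 = pwd_toy(1.0, p=0.5, lr=0.01, steps=1000, s_every=20)
    print("plain GD, last 5 iterates:", gd[-5:])
    print("s updated every step, final w:", every_step[-1])
    print("s updated every 20 steps, final w:", every_20[-1])
```

With these settings the plain-gradient iterates keep being kicked away from zero, because a weight closer to zero receives a larger step, whereas both decay variants shrink toward zero, in line with the qualitative behavior the caption reports.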

Theorems & Definitions (6)

  • Theorem 1.1
  • proof
  • Lemma 3.1
  • proof
  • Theorem 3.2
  • proof