
Improving robustness to corruptions with multiplicative weight perturbations

Trung Trinh, Markus Heinonen, Luigi Acerbi, Samuel Kaski

TL;DR

This work introduces Data Augmentation via Multiplicative Perturbations (DAMP), a training method that perturbs neural network weights multiplicatively with random noise to simulate input corruptions and improve robustness across a wide range of distortions. It establishes a theoretical link between input perturbations and weight perturbations, and connects Adaptive Sharpness-Aware Minimization (ASAM) to adversarial multiplicative perturbations, showing similar update dynamics. Empirically, DAMP improves corruption robustness across CIFAR-10/100, TinyImageNet, and ImageNet on ResNet and Vision Transformers, including training ViT-S/16 from scratch on ImageNet with competitive results and in some cases surpassing more expensive methods like SAM/ASAM. DAMP can be combined with modern data augmentations (e.g., MixUp, RandAugment) and maintains comparable training cost to standard SGD, offering a practical, scalable approach to robustness in real-world vision systems.

Abstract

Deep neural networks (DNNs) excel on clean images but struggle with corrupted ones. Incorporating specific corruptions into the data augmentation pipeline can improve robustness to those corruptions but may harm performance on clean images and other types of distortion. In this paper, we introduce an alternative approach that improves the robustness of DNNs to a wide range of corruptions without compromising accuracy on clean images. We first demonstrate that input perturbations can be mimicked by multiplicative perturbations in the weight space. Leveraging this, we propose Data Augmentation via Multiplicative Perturbation (DAMP), a training method that optimizes DNNs under random multiplicative weight perturbations. We also examine the recently proposed Adaptive Sharpness-Aware Minimization (ASAM) and show that it optimizes DNNs under adversarial multiplicative weight perturbations. Experiments on image classification datasets (CIFAR-10/100, TinyImageNet and ImageNet) and neural network architectures (ResNet50, ViT-S/16, ViT-B/16) show that DAMP enhances model generalization performance in the presence of corruptions across different settings. Notably, DAMP is able to train a ViT-S/16 on ImageNet from scratch, reaching a top-1 error of 23.7%, which is comparable to ResNet50 without extensive data augmentations.
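The core idea of training under random multiplicative weight perturbations can be illustrated with a minimal sketch. The toy example below (our own illustration, not the paper's implementation; the function name, logistic-regression setting, and defaults like `sigma=0.2` are ours) samples a per-weight noise $\xi \sim \mathcal{N}(1, \sigma^2)$, evaluates the gradient at the perturbed weights $\mathbf{w} \circ \xi$, and applies that gradient to the clean weights, analogous to how sharpness-aware methods apply a gradient computed at a perturbed point:

```python
import numpy as np

def damp_step(w, x, y, lr=0.1, sigma=0.2, rng=None):
    """One illustrative DAMP-style step for logistic regression.

    Samples a multiplicative perturbation xi ~ N(1, sigma^2) per weight,
    computes the gradient at the perturbed weights w * xi, and updates
    the clean weights with it. This is a sketch of the idea only; the
    paper applies the technique to deep networks, not to this toy model.
    """
    rng = rng or np.random.default_rng(0)
    xi = rng.normal(1.0, sigma, size=w.shape)  # multiplicative noise
    wp = w * xi                                # perturbed weights w ∘ xi
    p = 1.0 / (1.0 + np.exp(-x @ wp))          # forward pass at perturbed point
    grad = x.T @ (p - y) / len(y)              # cross-entropy gradient
    return w - lr * grad                       # update the clean weights
```

Because the noise is resampled every step, each minibatch is effectively seen through a slightly different network, which is what lets the perturbations act as an implicit data augmentation while keeping per-step cost close to plain SGD.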

Paper Structure

This paper contains 47 sections, 2 theorems, 31 equations, 7 figures, 3 tables, 3 algorithms.

Key Result

Lemma 1

For all $h=1,\dots,H$ and for all $\mathbf{x} \in \mathcal{X}$, there exists a scalar $C_\mathbf{g}^{(h)}(\mathbf{x}) > 0$ such that:

Figures (7)

  • Figure 1: Depictions of a pre-activation neuron $z=\mathbf{w}^\top\mathbf{x}$ in the presence of (a) covariate shift $\boldsymbol{\epsilon}$, (b) a multiplicative weight perturbation (MWP) equivalent to $\boldsymbol{\epsilon}$, and (c) random MWPs $\boldsymbol{\xi}$. $\circ$ denotes the Hadamard product. Figs. (a) and (b) show that for a covariate shift $\boldsymbol{\epsilon}$, one can always find an equivalent MWP. From this intuition, we propose to inject random MWPs $\boldsymbol{\xi}$ into the forward pass during training, as shown in Fig. (c), to robustify a DNN to covariate shift.
  • Figure 2: Depiction of how a corruption $\mathbf{g}$ affects the output of a DNN. Here $\mathbf{x}_\mathbf{g}=\mathbf{g}(\mathbf{x})$. The corruption $\mathbf{g}$ creates a shift $\boldsymbol{\delta}_\mathbf{g}\mathbf{x}^{(0)}=\mathbf{x}_\mathbf{g}-\mathbf{x}$ in the input $\mathbf{x}$, which propagates into shifts $\boldsymbol{\delta}_\mathbf{g}\mathbf{x}^{(h)}$ in the output of each layer. This will eventually cause a shift in the loss $\boldsymbol{\delta}_\mathbf{g} \ell$. This figure explains why the model performance tends to degrade under corruption.
  • Figure 3: DAMP improves robustness to all corruptions while preserving accuracy on clean images. Results of ResNet18/CIFAR-100 experiments averaged over 5 seeds. The heatmap shows $\mathrm{CE}^{f}_c$ described in Eq. \ref{eq:ce_c_f} (lower is better), where each row corresponds to a tuple of training (method, corruption), while each column corresponds to the test corruption. The Avg column shows the average of the results of the previous columns. none indicates no corruption. We use the models trained under the SGD/none setting (first row) as baselines to calculate the $\mathrm{CE}^{f}_c$. The last five rows are the 5 best training corruptions ranked by the results in the Avg column.
  • Figure 4: DAMP surpasses SAM on corrupted images in most cases, despite requiring only half the training cost. We report the predictive errors (lower is better) averaged over 5 seeds. A severity level of $0$ indicates no corruption. We use the same number of epochs for all methods.
  • Figure 5: DAMP improves robustness to all corruptions while preserving accuracy on clean images. Results of ResNet18/CIFAR-10 experiments averaged over 3 seeds. The heatmap shows $\mathrm{CE}^{f}_c$ described in Eq. \ref{eq:ce_c_f}, where each row corresponds to a tuple of training (method, corruption), while each column corresponds to the test corruption. The Avg column shows the average of the results of the previous columns. none indicates no corruption. We use the models trained under the SGD/none setting (first row) as baselines to calculate the $\mathrm{CE}^{f}_c$. The last five rows are the 5 best training corruptions ranked by the results in the Avg column.
  • ...and 2 more figures

Theorems & Definitions (4)

  • Lemma 1
  • Theorem 1
  • Proofs (2)