Masks, Signs, And Learning Rate Rewinding

Advait Gadhikar, Rebekka Burkholz

TL;DR

This work examines why Learning Rate Rewinding (LRR) improves sparse network training, disentangling the effects of mask identification and parameter optimization. The authors analyze the gradient-flow dynamics of a minimal two-layer ReLU model and show that LRR can inherit advantageous weight signs from the overparameterized training phase and thus identify the ground-truth mask more reliably than Iterative Magnitude Pruning (IMP). They prove that, for a single hidden neuron, LRR succeeds in more cases than IMP, and that overparameterization (higher input dimension) enables the sign switches that make this possible. Empirically, the study supports these insights on CIFAR-10/100, Tiny ImageNet, and ImageNet with ResNet architectures, showing that LRR outperforms IMP across masks (including random ones) and sparsities, especially when BN rewinding and warmup are employed. The results suggest that flipping parameter signs early and carrying them through pruning iterations makes sparse training more robust, and could guide the design of practical sparse training algorithms that operate from scratch.

Abstract

Learning Rate Rewinding (LRR) has been established as a strong variant of Iterative Magnitude Pruning (IMP) to find lottery tickets in deep overparameterized neural networks. While both iterative pruning schemes couple structure and parameter learning, understanding how LRR excels in both aspects can bring us closer to the design of more flexible deep learning algorithms that can optimize diverse sets of sparse architectures. To this end, we conduct experiments that disentangle the effect of mask learning and parameter optimization and how both benefit from overparameterization. The ability of LRR to flip parameter signs early and stay robust to sign perturbations seems to make it not only more effective in mask identification but also in optimizing diverse sets of masks, including random ones. In support of this hypothesis, we prove in a simplified single hidden neuron setting that LRR succeeds in more cases than IMP, as it can escape initially problematic sign configurations.
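The difference between the two pruning schemes can be sketched on a toy problem. The code below is not the authors' implementation: it assumes the common description of IMP (after each pruning step, surviving weights are reset to their initial values) and of LRR (training continues from the trained weights, and only the learning-rate schedule restarts). The single-neuron model $f(x) = a\phi(w \cdot x)$ with target $\phi(x_1)$, the function names `train`, `prune_smallest`, and `iterative_pruning`, and all hyperparameters are illustrative assumptions.

```python
import numpy as np

def train(a, w, mask, x, t, lr=0.05, steps=500):
    """Gradient descent on 0.5 * mean((a * relu(x @ w) - t)^2); pruned weights stay zero.
    A constant step size stands in for the schedule that restarts every round."""
    w = w * mask
    for _ in range(steps):
        pre = x @ w
        h = np.maximum(pre, 0.0)
        res = a * h - t
        grad_a = np.mean(res * h)
        grad_w = (res * a * (pre > 0)) @ x / len(x)
        a = a - lr * grad_a
        w = (w - lr * grad_w) * mask
    return a, w

def prune_smallest(w, mask, n_prune):
    """Remove the n_prune surviving weights with the smallest magnitude."""
    alive = np.flatnonzero(mask)
    drop = alive[np.argsort(np.abs(w[alive]))[:n_prune]]
    new_mask = mask.copy()
    new_mask[drop] = 0.0
    return new_mask

def iterative_pruning(a0, w0, x, t, rounds, n_prune, rewind_weights):
    """rewind_weights=True: IMP-style reset to the initial values after each pruning step.
    rewind_weights=False: LRR-style continuation from the trained values."""
    mask = np.ones_like(w0)
    a, w = a0, w0.copy()
    for _ in range(rounds):
        a, w = train(a, w, mask, x, t)
        mask = prune_smallest(w, mask, n_prune)
        if rewind_weights:
            a, w = a0, w0 * mask   # IMP: trained values (and their signs) are discarded
        else:
            w = w * mask           # LRR: trained values (and their signs) are kept
    a, w = train(a, w, mask, x, t)
    return a, w, mask

# Hypothetical setup: 4 input dimensions, only x_1 matters, w_1 starts negative.
rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 4))
t = np.maximum(x[:, 0], 0.0)
w0 = 0.3 * rng.standard_normal(4)
w0[0] = -abs(w0[0])
a0 = float(np.linalg.norm(w0))  # balanced initialization: |a| = ||w||

for name, rewind in [("IMP-style", True), ("LRR-style", False)]:
    a, w, mask = iterative_pruning(a0, w0, x, t, rounds=3, n_prune=1, rewind_weights=rewind)
    loss = 0.5 * np.mean((a * np.maximum(x @ w, 0.0) - t) ** 2)
    print(name, "mask:", mask, "w:", np.round(w, 3), "loss:", round(float(loss), 4))
```

The only difference between the two modes is the branch after pruning, which is where a sign learned during the overparameterized phase is either discarded or inherited. Which mask and final loss each mode reaches depends on the seed and initialization, so the printed values should be read as one run rather than as a reproduction of the paper's results.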

Paper Structure (13 sections, 4 theorems, 9 equations, 18 figures, 1 table)

Key Result

Theorem 2.1

Let a target $t(x) = \phi(x)$ and network $f(x) = a\phi(wx)$ be given such that $a$ and $w$ follow the gradient flow dynamics (eq:objective) with a random balanced parameter initialization and sufficiently many samples. If $a(0) > 0$ and $w(0) > 0$, $f(x)$ can learn the correct target. In all other
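The one-dimensional setting of Theorem 2.1 can be explored numerically. The following sketch is not from the paper: it approximates gradient flow with small-step gradient descent on the squared loss over Gaussian samples, and the function name `train_single_neuron`, the step size, the sample count, and the initial magnitude 0.5 are assumptions chosen for illustration.

```python
import numpy as np

def train_single_neuron(a0, w0, lr=0.01, steps=2000, n=2000, seed=0):
    """Fit f(x) = a * relu(w * x) to t(x) = relu(x) with full-batch gradient descent."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    t = np.maximum(x, 0.0)
    a, w = float(a0), float(w0)
    for _ in range(steps):
        pre = w * x
        h = np.maximum(pre, 0.0)
        res = a * h - t                            # residual f(x) - t(x)
        grad_a = np.mean(res * h)                  # d loss / d a
        grad_w = np.mean(res * a * (pre > 0) * x)  # d loss / d w
        a -= lr * grad_a
        w -= lr * grad_w
    loss = 0.5 * np.mean((a * np.maximum(w * x, 0.0) - t) ** 2)
    return a, w, loss

# Balanced initialization (|a(0)| = |w(0)|) in each of the four sign quadrants.
for a0, w0 in [(0.5, 0.5), (0.5, -0.5), (-0.5, 0.5), (-0.5, -0.5)]:
    a, w, loss = train_single_neuron(a0, w0)
    print(f"a(0)={a0:+.1f}, w(0)={w0:+.1f} -> a={a:+.3f}, w={w:+.3f}, loss={loss:.4f}")
```

Under these assumptions, only the run with $a(0) > 0$ and $w(0) > 0$ should drive the loss toward zero. When $w(0) < 0$, the neuron is active only for $x < 0$, where the target is zero, so training shrinks the output instead of fitting the target; when $a(0) < 0$ and $w(0) > 0$, both parameters decay toward zero without changing sign over the simulated horizon.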

Figures (18)

  • Figure 1: (a) Target network. For one-dimensional input, learning succeeds when both initial values are positive, $w(0), a(0) > 0$ (yellow quadrant), but fails in all other cases (red). (b) For multidimensional input, IMP identifies the correct mask, but cannot learn the target if the model is reinitialized to $w^{(2)}(0) < 0$. (c) LRR identifies the correct mask and is able to inherit the correct initial sign $w^{(2)}(0) > 0$ from the trained overparameterized model if $a^{(0)}(0) > 0$.
  • Figure 2: (a) IMP and LRR for a single hidden neuron network. (b, c) Mask randomization for (b) CIFAR10 and (c) CIFAR100. (d) LRR optimizes the IMP mask more effectively on Tiny ImageNet.
  • Figure 3: The sparse mask learnt by LRR is superior and the performance of IMP is improved in combination with the LRR mask on (a, b) CIFAR10 and (c, d) CIFAR100.
  • Figure 4: LRR improves parameter optimization within the mask learnt by IMP for (a, b) CIFAR10 and (c, d) CIFAR100.
  • Figure 5: (top) The pruning iteration after which the parameter signs no longer change is much earlier for LRR (purple) than for IMP (orange). (bottom) The number of times a parameter switches sign over pruning iterations on (a) CIFAR10, (b) CIFAR100, and (c) Tiny ImageNet.
  • ...and 13 more figures

Theorems & Definitions (6)

  • Theorem 2.1
  • Lemma 2.2
  • Theorem 2.3
  • Theorem A.1
  • proof
  • proof