DASH: Warm-Starting Neural Network Training in Stationary Settings without Loss of Plasticity

Baekrok Shin, Junsoo Oh, Hanseul Cho, Chulhee Yun

TL;DR

This paper proposes DASH, a method that aims to mitigate plasticity loss by selectively forgetting memorized noise while preserving learned features. It is validated on vision tasks, where it improves both test accuracy and training efficiency.

Abstract

Warm-starting neural network training by initializing networks with previously learned weights is appealing, as practical neural networks are often deployed under a continuous influx of new data. However, it often leads to loss of plasticity, where the network loses its ability to learn new information, resulting in worse generalization than training from scratch. This occurs even under stationary data distributions, and its underlying mechanism is poorly understood. We develop a framework emulating real-world neural network training and identify noise memorization as the primary cause of plasticity loss when warm-starting on stationary data. Motivated by this, we propose Direction-Aware SHrinking (DASH), a method aiming to mitigate plasticity loss by selectively forgetting memorized noise while preserving learned features. We validate our approach on vision tasks, demonstrating improvements in test accuracy and training efficiency.

Paper Structure

This paper contains 42 sections, 7 theorems, 23 equations, 23 figures, 10 tables, 5 algorithms.

Key Result

Theorem 3.4

There exists nonempty ${\mathcal{G}} \subsetneq {\mathcal{S}}$ such that we always obtain ${\mathcal{L}}_\mathrm{warm}^{(1)} = {\mathcal{L}}_\mathrm{cold}^{(1)} = {\mathcal{G}}$. For all $J \geq 2$, the following inequalities hold: Furthermore, $\mathrm{ACC}({\mathcal{L}}_{\rm warm}^{(J)}) < \mathrm{ACC}({\mathcal{L}}_{\rm cold}^{(J)})$ holds when $J > \frac{\gamma}{\delta n}$ where $\delta \tria

Figures (23)

  • Figure 1: Performance comparison of various methods on Tiny-ImageNet using ResNet-18. The same hyperparameters are used across all methods. The dataset is divided into 50 chunks, with a constant number of data points added to the training dataset in each experiment (x-axis), reaching the full dataset at the 50th experiment. Models are trained until achieving 99.9% train accuracy before proceeding to the next experiment; the plot on the right reports the number of update steps executed in each experiment. Results are averaged over three random seeds. "Cold" refers to cold-starting and "Warm" refers to warm-starting. The Shrink & Perturb (S&P) method involves shrinking the model weights by a constant factor and adding noise (Ash & Adams, 2020). Notably, DASH, our proposed method, achieves better generalization performance than both training from scratch and S&P, while requiring fewer steps to converge.
  • Figure 2: The plot shows the test accuracy (left y-axis) when the model is pretrained for varying epochs (x-axis) and then fine-tuned on the full data, along with the pretrain accuracy (right y-axis) plotted in brown. We trained a three-layer MLP (left) and ResNet-18 (right). Each transparent line and point corresponds to a specific random seed, and the median values are highlighted with opaque markers and solid lines. 'Random' corresponds to training from random initialization (cold-start).
  • Figure 3: Comparison of random, warm, and ideal methods across 10 random seeds (mean ± std dev). The test accuracy (left) and the number of learned features across all classes (middle) are nearly identical for random and ideal initializations, causing their plots to overlap. Warm initialization, however, exhibits lower test accuracy compared to both methods. Regarding training time (right), there is a significant gap between random and warm initialization, which the ideal method addresses.
  • Figure 4: Illustration of DASH. We compute the loss $L$ with training data ${\mathcal{T}}_{1:j}$ and obtain the negative gradient. Then, we shrink the weights proportionally to the cosine similarity between the current weight $\theta$ and $\nabla_{\theta} L$, resulting in $\Tilde{\theta}$.
  • Figure 5: An illustration of our proposed feature learning framework with a single class. Each data point is composed of features, as shown in the data panel, and is learned through the framework depicted in the model panel.
  • ...and 18 more figures
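
The two weight-shrinking ideas named in the captions above (S&P in Figure 1, direction-aware shrinking in Figure 4) can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names, the per-row granularity, the lower bound `min_factor`, and the clipping rule `factor = clip(cos, min_factor, 1)` are all hypothetical choices made here for concreteness. The cosine similarity is taken against the negative gradient that the Figure 4 caption mentions, which is likewise an interpretation.

```python
import math
import random


def shrink_and_perturb(weights, shrink=0.5, noise_std=0.01, rng=None):
    """S&P as summarized in the Figure 1 caption: scale every weight by the
    same constant factor, then add Gaussian noise. The default values are
    illustrative assumptions, not taken from the paper."""
    rng = rng or random.Random(0)
    return [[shrink * w + rng.gauss(0.0, noise_std) for w in row]
            for row in weights]


def cosine_similarity(u, v, eps=1e-12):
    """Cosine of the angle between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)


def direction_aware_shrink(weights, grads, min_factor=0.3):
    """Shrink each weight row by a factor tied to its alignment with the
    negative gradient of the loss.

    Assumption (hypothetical rule): factor = clip(cos, min_factor, 1), so a
    row aligned with the negative gradient is kept almost intact, while a
    poorly aligned row is shrunk down to min_factor."""
    shrunk = []
    for row, grad in zip(weights, grads):
        neg_grad = [-g for g in grad]
        cos = cosine_similarity(row, neg_grad)
        factor = min(1.0, max(min_factor, cos))
        shrunk.append([factor * w for w in row])
    return shrunk


if __name__ == "__main__":
    theta = [[1.0, 0.0], [0.0, 1.0]]
    grad = [[-2.0, 0.0], [0.0, 3.0]]  # row 0 aligned with -grad, row 1 opposed
    print(direction_aware_shrink(theta, grad))
```

In this toy example the first row points along the negative gradient and is left essentially unchanged, while the second row points against it and is scaled by `min_factor`. The contrast with `shrink_and_perturb`, which applies one factor to every weight regardless of direction, is the point of the sketch.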

Theorems & Definitions (16)

  • Remark 2.1
  • Remark 3.2
  • Theorem 3.4
  • proof : Proof Idea
  • Remark 3.5
  • Theorem 3.6
  • Lemma D.1
  • proof : Proof of Lemma D.1
  • Lemma D.2
  • proof : Proof of Lemma D.2
  • ...and 6 more