Pay Attention to Small Weights

Chao Zhou, Tom Jacobs, Advait Gadhikar, Rebekka Burkholz

TL;DR

NanoAdam tackles the high memory and compute costs of finetuning large pretrained models by exploiting a consistent gradient–weight relationship observed during finetuning, where large gradients tend to occur on small-magnitude weights. It introduces a gradient-free, per-layer bottom-$k$ masking strategy with a density scheduler to update only small weights, enabling larger effective learning rates and reducing memory. Theoretical analysis in a two-layer teacher–student model shows updating small weights preserves the original representation and mitigates catastrophic forgetting, while empirical results on NLP (GLUE) and CV (CIFAR-10, Flowers102) show improved generalization and smaller parameter drift compared with baselines. This approach scales to large models and provides memory-efficient continual learning benefits across NLP and vision tasks.

Abstract

Finetuning large pretrained neural networks is known to be resource-intensive, both in terms of memory and computational cost. To mitigate this, a common approach is to restrict training to a subset of the model parameters. By analyzing the relationship between gradients and weights during finetuning, we observe a notable pattern: large gradients are often associated with small-magnitude weights. This correlation is more pronounced in finetuning settings than in training from scratch. Motivated by this observation, we propose NANOADAM, which dynamically updates only the small-magnitude weights during finetuning and offers several practical advantages: first, this criterion is gradient-free, so the parameter subset can be determined without gradient computation; second, it preserves large-magnitude weights, which are likely to encode critical features learned during pretraining, thereby reducing the risk of catastrophic forgetting; third, it permits the use of larger learning rates and consistently leads to better generalization performance in experiments. We demonstrate this for both NLP and vision tasks.
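The selection rule described above can be illustrated with a minimal sketch. It is not the authors' implementation: the helper names `bottom_k_mask` and `masked_sgd_step`, the `density` parameter, and the toy values are assumptions made for illustration, and a plain SGD step stands in for the Adam update. The sketch only shows that the mask is computed from weight magnitudes alone, with no gradient needed to choose which entries to update.

```python
def bottom_k_mask(weights, density):
    """Return a 0/1 mask marking the `density` fraction of entries with smallest |w|.

    Hypothetical helper: the mask depends only on the weights, not on gradients.
    """
    k = max(1, int(round(density * len(weights))))
    # Sort indices by absolute value, smallest first; ties are broken by index.
    order = sorted(range(len(weights)), key=lambda i: (abs(weights[i]), i))
    selected = set(order[:k])
    return [1 if i in selected else 0 for i in range(len(weights))]


def masked_sgd_step(weights, grads, mask, lr):
    """Update only the masked entries (plain SGD as a stand-in for Adam)."""
    return [w - lr * g * m for w, g, m in zip(weights, grads, mask)]


# Toy "layer" with six weights; all values are made up for illustration.
layer = [0.9, -0.05, 0.4, 0.01, -1.2, 0.1]
grads = [0.1, 0.8, 0.2, -0.6, 0.05, 0.3]

mask = bottom_k_mask(layer, density=0.5)
updated = masked_sgd_step(layer, grads, mask, lr=0.1)

print(mask)     # the three smallest-magnitude weights are selected
print(updated)  # large-magnitude weights are left untouched
```

In this sketch the mask is applied per layer by calling `bottom_k_mask` once for each layer's weights. A memory-saving implementation would additionally keep optimizer state only for the selected entries, which the sketch does not attempt.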

Paper Structure

This paper contains 66 sections, 3 theorems, 10 equations, 19 figures, 28 tables, 1 algorithm.

Key Result

Theorem 2.2

Assume a model $f(x)$ consisting of $n$ neurons learns the teacher $f_{\text{teacher}}(x)$ corresponding to a pre-training task so that $f(x) = f_{\text{teacher}}(x)$ for all $x\in \mathbb{R}^d$. Furthermore, let $f(x)$ consist of at least two neurons $i,r \in [n]$ such that $\max\{|a_i|^2, |a_r|^2\

Figures (19)

  • Figure 1: The relationship between gradients and weights during finetuning (FT) and training from scratch. The x-axis represents the magnitude of the weights, while the y-axis represents the magnitude of the gradients. From left to right, the subfigures correspond to an FT NLP task, an FT CV task, training a CV task from scratch at an early step, and training a CV task from scratch at a later step.
  • Figure 2: Overlap between small weights and large gradients.
  • Figure 3: Nano gradient descent provably prevents catastrophic forgetting. (a) Nano gradient descent keeps the original representation while learning the extra neuron. (b) The largest gradients can correspond to weights with large magnitudes, leading to unlearning of the original representation and an inability to learn the new representation.
  • Figure 4: Generalization performance of different masking strategies in NanoAdam using the same gradient density. (a) Small vs. large vs. random weights. (b) Small weights vs. large gradients. Small-weight masking achieves the best generalization performance.
  • Figure 5: The dynamics of the relationship between gradients and weights during finetuning of BERT-base on CoLA. The x-axis represents the magnitude of the weights, while the y-axis represents the magnitude of the gradients. From left to right, the subfigures correspond to the early, middle, and late stages of finetuning. From top to bottom, the subfigures represent progressively deeper layers in the network.
  • ...and 14 more figures

Theorems & Definitions (4)

  • Definition 2.1
  • Theorem 2.2
  • Theorem E.1
  • Lemma E.2