
Weight Decay Improves Language Model Plasticity

Tessa Han, Sebastian Bordt, Hanlin Zhang, Sham Kakade

TL;DR

This work investigates how weight decay during pretraining influences language model plasticity—the ability to adapt to downstream tasks. By systematically varying the weight decay parameter across Llama-2 and OLMo-2 models and evaluating downstream performance after fine-tuning, the authors show that larger weight decay often enhances downstream adaptability even when pretraining loss worsens, highlighting that pretraining loss is not a reliable sole predictor of downstream success. They reveal mechanistic effects of weight decay, including more linearly separable representations, reduced attention-matrix rank, and diminished pretraining overfitting, which together help explain improved plasticity. The study argues for incorporating downstream objectives into hyperparameter optimization and provides a nuanced view of weight decay’s multifaceted role in shaping model behavior across training stages.
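The parameter being varied is the optimizer's weight decay coefficient. The text here does not state the update rule, so the sketch below assumes AdamW-style decoupled weight decay, where each step shrinks every weight by `lr * weight_decay * p` in addition to the adaptive gradient step. The function name `adamw_step` and its default hyperparameters are illustrative assumptions. Only `weight_decay=0.1` mirrors the standard default mentioned in the figure captions.

```python
import math
from typing import List, Tuple


def adamw_step(
    params: List[float],
    grads: List[float],
    m: List[float],
    v: List[float],
    t: int,
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    eps: float = 1e-8,
    weight_decay: float = 0.1,
) -> Tuple[List[float], List[float], List[float]]:
    """One AdamW-style step on flat lists of scalars; t is the 1-based step count."""
    new_p, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        # Exponential moving averages of the gradient and the squared gradient.
        mi = beta1 * mi + (1 - beta1) * g
        vi = beta2 * vi + (1 - beta2) * g * g
        # Bias correction for the zero-initialized moment estimates.
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        # Decoupled weight decay: weight_decay * p is added outside the
        # adaptive rescaling, so every weight shrinks by lr * weight_decay * p.
        p = p - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * p)
        new_p.append(p)
        new_m.append(mi)
        new_v.append(vi)
    return new_p, new_m, new_v


if __name__ == "__main__":
    # With zero gradients the step is pure shrinkage: p <- (1 - lr * wd) * p.
    for wd in (0.0, 0.1, 0.5):
        p, m, v = [1.0], [0.0], [0.0]
        for t in range(1, 101):
            p, m, v = adamw_step(p, [0.0], m, v, t, lr=0.01, weight_decay=wd)
        print(f"weight_decay={wd}: weight after 100 steps = {p[0]:.4f}")
```

Because the decay term sits outside the Adam rescaling, a larger coefficient pulls weights toward zero at a rate that does not depend on gradient magnitudes. The paper studies how that extra pull during pretraining affects later fine-tuning.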

Abstract

The prevailing paradigm in large language model (LLM) development is to pretrain a base model, then perform further training to improve performance and model behavior. However, hyperparameter optimization and scaling laws have been studied primarily from the perspective of the base model's validation loss, ignoring downstream adaptability. In this work, we study pretraining from the perspective of model plasticity, that is, the ability of the base model to successfully adapt to downstream tasks through fine-tuning. We focus on the role of weight decay, a key regularization parameter during pretraining. Through systematic experiments, we show that models trained with larger weight decay values are more plastic, meaning they show larger performance gains when fine-tuned on downstream tasks. This phenomenon can lead to counterintuitive trade-offs where base models that perform worse after pretraining can perform better after fine-tuning. Further investigation of weight decay's mechanistic effects on model behavior reveals that it encourages linearly separable representations, regularizes attention matrices, and reduces overfitting on the training data. In conclusion, this work demonstrates the importance of using evaluation metrics beyond cross-entropy loss for hyperparameter optimization and casts light on the multifaceted role that a single optimization hyperparameter plays in shaping model behavior.
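The claim that weight decay encourages linearly separable representations is measured with linear probes: a linear classifier is trained on frozen hidden states, and its held-out accuracy indicates how linearly separable a property (such as sentiment or topic) is in those states. The probe's exact setup is not given here, so the sketch below assumes a softmax-regression probe trained by plain gradient descent on standardized features. The helper `synthetic_features` is a hypothetical stand-in for hidden states extracted from a pretrained model.

```python
import numpy as np


def train_linear_probe(X, y, num_classes, lr=0.1, steps=500, l2=1e-4):
    """Fit a softmax-regression probe on frozen features X (n, d) with int labels y."""
    n, d = X.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    Y = np.eye(num_classes)[y]  # one-hot targets, shape (n, C)
    for _ in range(steps):
        logits = X @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad_logits = (probs - Y) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * (X.T @ grad_logits + l2 * W)
        b -= lr * grad_logits.sum(axis=0)
    return W, b


def probe_accuracy(X_train, y_train, X_test, y_test, num_classes):
    """Held-out accuracy of a linear probe on standardized features."""
    mu = X_train.mean(axis=0)
    sd = X_train.std(axis=0) + 1e-8
    # Standardize both splits with the training statistics only.
    Xtr, Xte = (X_train - mu) / sd, (X_test - mu) / sd
    W, b = train_linear_probe(Xtr, y_train, num_classes)
    preds = np.argmax(Xte @ W + b, axis=1)
    return float(np.mean(preds == y_test))


def synthetic_features(rng, n, d, num_classes, separation):
    """Hypothetical stand-in for hidden states: one Gaussian cluster per class."""
    centers = rng.normal(size=(num_classes, d)) * separation
    y = rng.integers(0, num_classes, size=n)
    X = centers[y] + rng.normal(size=(n, d))
    return X, y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    for sep in (0.1, 3.0):
        X, y = synthetic_features(rng, n=600, d=16, num_classes=3, separation=sep)
        acc = probe_accuracy(X[:400], y[:400], X[400:], y[400:], num_classes=3)
        print(f"separation={sep}: probe accuracy={acc:.2f}")
```

In a real experiment, the features would be hidden states from models pretrained with different weight decay values, and the probe would be kept fixed across models so that accuracy differences reflect the representations rather than the classifier.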

Paper Structure

This paper contains 25 sections, 4 equations, 19 figures, and 7 tables.

Figures (19)

  • Figure 1: Pretraining validation cross-entropy loss of models pretrained with varying weight decay. The weight decay value that minimizes pretraining loss may be equal to or larger than the standard default value of 0.1 depending on the training regime.
  • Figure 2: Weight decay during pretraining improves language model plasticity and downstream performance. This figure plots the average accuracy after fine-tuning for models pretrained with varying weight decay. The results indicate that larger weight decay leads to better downstream performance, suggesting that it helps the pretrained model learn during fine-tuning and thus improves plasticity. In these experiments, the optimal weight decay for downstream performance is larger than the standard default of 0.1. In addition, the optimal weight decay based on pretraining loss (Figure 1) differs from that based on fine-tuning accuracy (this figure), suggesting that optimizing hyperparameters based solely on pretraining loss does not always produce models with the best downstream performance.
  • Figure 3: A model's performance after pretraining is not perfectly predictive of its performance downstream. Models with similar pretraining losses can perform differently downstream, and models with lower validation cross-entropy loss after pretraining can perform better or worse downstream (i.e., after fine-tuning) than models with higher pretraining losses.
  • Figure 4: Weight decay encourages linearly separable representations. This figure depicts the accuracy of linear probes that predict sentiment and topic from the representations of models pretrained with different weight decay values. Linear probes achieve higher accuracy when models are pretrained with a weight decay greater than the default of 0.1.
  • Figure 5: Weight decay reduces the rank of attention matrices. This figure depicts the average pseudo-rank (defined in the paper's appendix) of the query-key ($W_{QK}$) and value projection ($W_{VP}$) matrices in layers 5 and 15 during the training of OLMo-2-1B models at 20 tokens per parameter (TPP).
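Figure 5 tracks a pseudo-rank of attention matrices. The paper's exact definition is in its appendix and is not reproduced here, so the sketch below uses a hypothetical stand-in: the smallest number of singular values that together capture 90% of a matrix's squared-singular-value energy. It forms the per-head query-key product $W_{QK} = W_Q^h (W_K^h)^\top$ under an assumed layout in which each head owns a contiguous block of `d_head` columns of `W_Q` and `W_K`. The same `pseudo_rank` function could equally be applied to $W_{VP}$.

```python
import numpy as np


def pseudo_rank(M, energy=0.9):
    """Smallest k whose top-k singular values hold `energy` of sum(s**2).

    Hypothetical definition for illustration; the paper defines its own
    pseudo-rank in an appendix.
    """
    s = np.linalg.svd(M, compute_uv=False)
    total = np.sum(s ** 2)
    if total == 0:
        return 0  # the zero matrix has no energy to capture
    cum = np.cumsum(s ** 2) / total  # nondecreasing, ends at ~1.0
    k = int(np.searchsorted(cum, energy) + 1)
    return min(k, len(s))  # guard against float round-off near energy=1


def average_qk_pseudo_rank(W_Q, W_K, num_heads, energy=0.9):
    """Average pseudo-rank of per-head W_QK = W_Q^h (W_K^h)^T.

    W_Q, W_K: (d_model, d_model) arrays whose columns are split into
    `num_heads` contiguous blocks of size d_head (an assumed layout).
    """
    d_model = W_Q.shape[0]
    d_head = d_model // num_heads
    ranks = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        W_QK = W_Q[:, cols] @ W_K[:, cols].T  # (d_model, d_model), rank <= d_head
        ranks.append(pseudo_rank(W_QK, energy))
    return float(np.mean(ranks))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d_model, num_heads = 64, 4
    W_Q, W_K = rng.normal(size=(2, d_model, d_model))
    print("random W_Q:", average_qk_pseudo_rank(W_Q, W_K, num_heads))
    # Hypothetical contrast case: a query projection of rank 4.
    W_Q_low = rng.normal(size=(d_model, 4)) @ rng.normal(size=(4, d_model))
    print("rank-4 W_Q:", average_qk_pseudo_rank(W_Q_low, W_K, num_heads))
```

An energy threshold is used rather than a hard singular-value cutoff so that the measure is insensitive to the overall scale of the weights, which weight decay itself changes.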