Disentangling the Causes of Plasticity Loss in Neural Networks

Clare Lyle, Zeyu Zheng, Khimya Khetarpal, Hado van Hasselt, Razvan Pascanu, James Martens, Will Dabney

TL;DR

The paper addresses why neural networks lose plasticity under nonstationary data, particularly in reinforcement learning. It shows that plasticity loss arises from several independent mechanisms and that their effects can be diagnosed from patterns in the empirical neural tangent kernel (NTK). Proposing a "Swiss cheese" mitigation that stacks layer normalization with weight decay, so that each intervention covers mechanisms the other misses, the authors demonstrate additive improvements across synthetic nonstationary tasks, Atari, DM Control, and natural distribution shifts. This modular approach simplifies stabilizing optimization in nonstationary settings and points to practical ways of maintaining trainability without frequent parameter resets.
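
To make "empirical NTK patterns" concrete, here is a toy sketch, not the authors' code. It computes the empirical NTK K[i][j] = <grad f(x_i), grad f(x_j)> of a hypothetical one-hidden-layer ReLU network on a small batch, then reports the two statistics highlighted in Figure 3: the mean absolute cosine similarity between per-input gradients (interference) and the variance of gradient norms across inputs. The network, its sizes, and the names `forward_and_grad` and `ntk_diagnostics` are assumptions made for illustration. Pure Python keeps the example dependency-free.

```python
import math
import random


def forward_and_grad(W, v, x):
    """Hypothetical toy net f(x) = sum_j v[j] * relu(W[j] . x).

    Returns f(x) and the gradient of f with respect to all parameters,
    flattened as [dW_0..., dW_1..., ..., dv_0, dv_1, ...].
    """
    out, grad_W, grad_v = 0.0, [], []
    for w_j, v_j in zip(W, v):
        pre = sum(a * b for a, b in zip(w_j, x))
        act = max(pre, 0.0)
        out += v_j * act
        gate = 1.0 if pre > 0 else 0.0                # relu'(pre)
        grad_W.extend(v_j * gate * x_k for x_k in x)  # df/dW[j] = v[j] * relu'(pre) * x
        grad_v.append(act)                            # df/dv[j] = relu(pre)
    return out, grad_W + grad_v


def ntk_diagnostics(W, v, xs):
    """Empirical NTK on a batch plus two summary statistics."""
    grads = [forward_and_grad(W, v, x)[1] for x in xs]
    n = len(grads)
    # K[i][j] = <grad f(x_i), grad f(x_j)>
    K = [[sum(a * b for a, b in zip(gi, gj)) for gj in grads] for gi in grads]
    norms = [math.sqrt(K[i][i]) for i in range(n)]
    # Interference: mean |cosine similarity| between gradients of distinct inputs
    cosines = [
        abs(K[i][j]) / (norms[i] * norms[j])
        for i in range(n)
        for j in range(n)
        if i != j and norms[i] > 0 and norms[j] > 0
    ]
    interference = sum(cosines) / len(cosines) if cosines else 0.0
    # Spread of per-input gradient norms
    mean_norm = sum(norms) / n
    norm_variance = sum((s - mean_norm) ** 2 for s in norms) / n
    return K, interference, norm_variance


random.seed(0)
d, h, n = 4, 16, 8
W = [[random.gauss(0, 1 / math.sqrt(d)) for _ in range(d)] for _ in range(h)]
v = [random.gauss(0, 1 / math.sqrt(h)) for _ in range(h)]
xs = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
K, interference, norm_variance = ntk_diagnostics(W, v, xs)
print(f"mean |cos| = {interference:.3f}, grad-norm variance = {norm_variance:.4f}")
```

Comparing these two numbers for a freshly initialized network and for one trained through several task changes is one way to probe the pattern the paper reports. In practice an autodiff framework would supply the per-example gradients.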

Abstract

Underpinning the past decades of work on the design, initialization, and optimization of neural networks is a seemingly innocuous assumption: that the network is trained on a stationary data distribution. In settings where this assumption is violated, e.g., deep reinforcement learning, learning algorithms become unstable and brittle with respect to hyperparameters and even random seeds. One factor driving this instability is the loss of plasticity, meaning that updating the network's predictions in response to new information becomes more difficult as training progresses. While many recent works provide analyses and partial solutions to this phenomenon, a fundamental question remains unanswered: to what extent do known mechanisms of plasticity loss overlap, and how can mitigation strategies be combined to best maintain the trainability of a network? This paper addresses these questions, showing that loss of plasticity can be decomposed into multiple independent mechanisms and that, while intervening on any single mechanism is insufficient to avoid the loss of plasticity in all cases, intervening on multiple mechanisms in conjunction results in highly robust learning algorithms. We show that a combination of layer normalization and weight decay is highly effective at maintaining plasticity in a variety of synthetic nonstationary learning tasks, and further demonstrate its effectiveness on naturally arising nonstationarities, including reinforcement learning in the Arcade Learning Environment.
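
The abstract's recommended combination pairs layer normalization with weight decay. The minimal sketch below illustrates what each component contributes. It assumes layer normalization without a learned gain or bias, applied before the ReLU, and plain SGD, where an L2 penalty and weight decay coincide. Layer norm makes a hidden layer's output nearly invariant to the scale of its weights, and the L2 term shrinks parameters geometrically. The placement, the hyperparameters (`lr`, `l2`, `eps`), and the function names are hypothetical, not taken from the paper.

```python
import math


def layer_norm(z, eps=1e-5):
    """Normalize pre-activations to zero mean, unit variance (no learned gain or bias)."""
    mu = sum(z) / len(z)
    var = sum((zi - mu) ** 2 for zi in z) / len(z)
    return [(zi - mu) / math.sqrt(var + eps) for zi in z]


def hidden_layer(W, x):
    """Hypothetical hidden layer: ReLU applied after layer-normalizing W @ x."""
    pre = [sum(w * xk for w, xk in zip(row, x)) for row in W]
    return [max(0.0, a) for a in layer_norm(pre)]


def sgd_step_l2(params, grads, lr=0.1, l2=1e-2):
    """One SGD step on loss + (l2 / 2) * ||params||^2.

    For plain SGD this is identical to weight decay: every parameter is
    multiplied by (1 - lr * l2) in addition to the usual gradient step.
    """
    return [p - lr * (g + l2 * p) for p, g in zip(params, grads)]


W = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
x = [1.0, 2.0]
# Layer norm makes the hidden output (almost) invariant to rescaling W;
# the only difference comes from eps.
h_small = hidden_layer(W, x)
h_large = hidden_layer([[10.0 * w for w in row] for row in W], x)
print(max(abs(a - b) for a, b in zip(h_small, h_large)))

# With no task gradient, the L2 term shrinks parameters geometrically:
# each entry is scaled by (1 - 0.1 * 0.01) ** 100, roughly 0.905.
params = [1.0, -2.0]
for _ in range(100):
    params = sgd_step_l2(params, grads=[0.0, 0.0])
print(params)
```

With adaptive optimizers such as Adam, an L2 penalty and decoupled weight decay are no longer equivalent. Which variant and optimizer the paper pairs with layer normalization is not specified in this summary.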

Paper Structure (38 sections, 5 equations, 28 figures)

Figures (28)

  • Figure 1: Left. Illustration of the relationship between pretraining target magnitude and optimization speed on a new task. Increasing the magnitude of the regression targets during pretraining has a strong dose-response effect on the final loss of a fine-tuning task, and the learning curves follow trends similar to those observed by Lyle et al. (2023) in DQN agents trained on contextual bandits. Right. Dose-response curves for the effect of target offset scale and of the magnitude of the distribution shift on plasticity. More severe distribution shifts, e.g. randomizing an entire dataset at once rather than gradually, result in more extreme loss of plasticity.
  • Figure 2: Accumulation of dead units after a task change: we visualize the learning dynamics of a small MLP trained to memorize random labels of the MNIST dataset immediately after a task change. Left. An early phase of training in which predictive entropy sharply increases along with the number of dead units. Middle. The distribution shift in the pre-activations during this early phase, together with the concomitant spike at the start of learning in gradients that have a negative dot product with all incoming features; later in training the dynamics are stable. Right. Similar spikes in linearized units occur immediately after a task change in convolutional networks trained with different activation functions on the CIFAR-10 dataset.
  • Figure 3: Left. Visualization of the effect of parameter scale on learning. Right. Visualization of the empirical NTKs of a variety of networks, taken from previously studied scenarios, that exhibit varying degrees of plasticity loss. Networks that have lost plasticity show greater interference than a random initialization (i.e. larger-magnitude cosine similarities in the top row) and also greater variance of gradient norms across inputs.
  • Figure 4: Comparison of interventions aimed at addressing different failure modes: overall, combining layer normalization with L2 regularization addressed plasticity loss in all classification problems we considered. Many other strategies also improve over doing nothing, but none outperforms this combination.
  • Figure 5: Left. Layer normalization and L2 regularization on synthetic nonstationary supervised classification problems. Right. Necessity of a scale-invariant output parameterization in a simple RL task.
  • ...and 23 more figures
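
Figure 2 tracks the number of dead units after a task change. A common operational definition, assumed here rather than taken from the paper, counts a ReLU unit as dead on a batch if its pre-activation is non-positive for every input in that batch. The sketch below implements that count for a single hypothetical dense layer with weights `W` and biases `b`. The helper names are illustrative.

```python
def preactivation(w, b, x):
    """Pre-activation w . x + b of a single hidden unit."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def count_dead_units(W, b, xs):
    """Count ReLU units that are inactive (pre-activation <= 0) on every input in xs.

    Such a unit outputs 0 for the whole batch, so its incoming weights
    receive no gradient from this batch.
    """
    return sum(
        1
        for w_j, b_j in zip(W, b)
        if all(preactivation(w_j, b_j, x) <= 0 for x in xs)
    )


W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b = [0.0, 0.0, -5.0]
xs = [[1.0, 1.0], [2.0, 0.5]]
print(count_dead_units(W, b, xs))  # 1: only the third unit is dead

# A large negative shift in the biases (a hypothetical stand-in for the
# post-task-change update spike in Figure 2) can kill every unit at once.
print(count_dead_units(W, [bj - 10.0 for bj in b], xs))  # 3
```

Tracking this count at each step right after a task change gives a curve of the kind shown in the left panel of Figure 2. Other papers use looser thresholds (e.g. "inactive on almost all inputs"), so absolute counts depend on the chosen definition.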