The Rank and Gradient Lost in Non-stationarity: Sample Weight Decay for Mitigating Plasticity Loss in Reinforcement Learning

Zihao Wu, Hongyao Tang, Yi Ma, Jiashun Liu, Yan Zheng, Jianye Hao

Abstract

Deep reinforcement learning (RL) suffers severely from plasticity loss due to its inherent non-stationarity, which impairs the ability to adapt to new data and learn continually. Unfortunately, our understanding of how plasticity loss arises, evolves, and can be resolved remains limited to empirical findings, leaving the theoretical side underexplored. To address this gap, we study the plasticity loss problem from the theoretical perspective of network optimization. By formally characterizing the two culprit factors in the online RL process: the non-stationarity of data distributions and the non-stationarity of targets induced by bootstrapping, our theory attributes the loss of plasticity to two mechanisms: the rank collapse of the Neural Tangent Kernel (NTK) Gram matrix and the $\Theta(\tfrac{1}{k})$ decay of gradient magnitude. The first mechanism echoes prior empirical findings from a theoretical perspective and sheds light on the effects of existing methods, e.g., network reset, neuron recycling, and noise injection. Against this backdrop, we focus primarily on the second mechanism and aim to alleviate plasticity loss by addressing the gradient attenuation issue, which is orthogonal to existing methods. We propose Sample Weight Decay (SWD), a lightweight method that restores gradient magnitude, as a general remedy to plasticity loss for deep RL methods based on experience replay. In experiments, we evaluate the efficacy of SWD on top of TD3, Double DQN, and SAC with the SimBa architecture in MuJoCo, ALE, and DeepMind Control Suite tasks. The results demonstrate that SWD effectively alleviates plasticity loss and consistently improves learning performance across various deep RL algorithms, update-to-data (UTD) ratios, network architectures, and environments, achieving state-of-the-art performance on the challenging DMC Humanoid tasks.
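The abstract describes SWD only at a high level: per-sample weights on replayed transitions are decayed to counteract the shrinking gradient magnitude. The paper's exact weighting rule is not given here; the following is a minimal sketch assuming an exponential decay of weights with sample age in the replay buffer (the decay rate `beta` and the helpers `sample_weights` and `weighted_td_loss` are hypothetical names, not the authors' implementation):

```python
import numpy as np

def sample_weights(ages, beta=1e-4):
    # Hypothetical weighting: older samples (larger age, measured in
    # environment steps since collection) receive exponentially smaller
    # weight. Normalizing by the mean keeps the average weight at 1,
    # so the overall loss scale is preserved.
    w = np.exp(-beta * np.asarray(ages, dtype=np.float64))
    return w / w.mean()

def weighted_td_loss(td_errors, ages, beta=1e-4):
    # Sample-weighted squared TD error for one replayed minibatch.
    # With beta=0 this reduces to the plain mean-squared TD loss.
    w = sample_weights(ages, beta)
    return float(np.mean(w * np.square(np.asarray(td_errors, dtype=np.float64))))
```

With `beta=0` the weights are uniform and the loss matches ordinary experience-replay TD learning, so the scheme degrades gracefully to the baseline.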

Paper Structure

This paper contains 56 sections, 7 theorems, 39 equations, 10 figures, 13 tables, 2 algorithms.

Key Result

Proposition 1

The empirical distribution satisfies a recursive update across training iterations.
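The equation itself did not survive extraction. A plausible reconstruction, assuming the standard growing-buffer average in which the $k$-th batch distribution $d_k$ enters with weight $1/k$ (consistent with the $\Theta(\tfrac{1}{k})$ gradient factor stated in the abstract; the symbols $\rho_k$ and $d_k$ are assumptions, not the paper's notation), is:

```latex
% Sketch (assumption): recursion for the empirical distribution \rho_k
% over a replay buffer after k collected batches; the new batch
% distribution d_k is mixed in with weight 1/k.
\rho_k = \left(1 - \frac{1}{k}\right)\rho_{k-1} + \frac{1}{k}\, d_k
```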

Figures (10)

  • Figure 1: Aggregate performance metrics [agarwal2021deep] with 95% stratified bootstrap CIs.
  • Figure 2: Empirical validation of SWD across TD3 in MuJoCo environments (mean $\pm$ std over 5 runs). SWD consistently improves sample efficiency and performance.
  • Figure 3: Empirical validation of SWD across Double DQN in ALE environments (mean $\pm$ std over 5 runs). SWD consistently improves sample efficiency and performance.
  • Figure 4: Performance comparison between SWD and PER built on SAC. Aggregate performance metrics [agarwal2021deep] with 95% stratified bootstrap CIs in DMC tasks.
  • Figure 5: Experiments in the humanoid-run environment show that SWA exhibits a lower gradient magnitude (GraMa) and inferior performance, validating our hypothesis.
  • ...and 5 more figures

Theorems & Definitions (12)

  • Proposition 1: Empirical distribution recursion
  • Proof (sketch)
  • Theorem 1: Population loss limit
  • Theorem 2: Suboptimality bound via squared Bellman residuals
  • Theorem 3: Gradient Dynamics at Initialization
  • Lemma 1
  • Proof
  • Proof
  • Lemma 2
  • Proof
  • ...and 2 more