Non-Stationary Learning of Neural Networks with Automatic Soft Parameter Reset

Alexandre Galashov, Michalis K. Titsias, András György, Clare Lyle, Razvan Pascanu, Yee Whye Teh, Maneesh Sahani

TL;DR

This work introduces a novel learning approach that automatically models and adapts to non-stationarity via an Ornstein-Uhlenbeck process with an adaptive drift parameter. The approach performs well in non-stationary supervised and off-policy reinforcement learning settings.

Abstract

Neural networks are traditionally trained under the assumption that data come from a stationary distribution. However, settings which violate this assumption are becoming more popular; examples include supervised learning under distributional shifts, reinforcement learning, continual learning and non-stationary contextual bandits. In this work we introduce a novel learning approach that automatically models and adapts to non-stationarity, via an Ornstein-Uhlenbeck process with an adaptive drift parameter. The adaptive drift tends to draw the parameters towards the initialisation distribution, so the approach can be understood as a form of soft parameter reset. We show empirically that our approach performs well in non-stationary supervised and off-policy reinforcement learning settings.
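
To make the "soft parameter reset" reading concrete, here is a minimal sketch of one drift step. It assumes a common discretised Ornstein-Uhlenbeck form, $\theta_{t+1} \sim \mathcal{N}(\gamma_t \theta_t + (1-\gamma_t)\mu_0, (1-\gamma_t^2)\sigma_0^2)$, which is not stated in the text above. The names `ou_drift_moments`, `ou_drift_step`, `mu0` and `sigma0` are hypothetical, and the drift parameter is fixed by hand here rather than estimated adaptively.

```python
import numpy as np


def ou_drift_moments(theta, gamma, mu0, sigma0):
    """Mean and std of the assumed drift N(gamma*theta + (1-gamma)*mu0, (1-gamma^2)*sigma0^2)."""
    mean = gamma * theta + (1.0 - gamma) * mu0
    std = np.sqrt(1.0 - gamma ** 2) * sigma0
    return mean, std


def ou_drift_step(theta, gamma, mu0, sigma0, rng):
    """Sample the next parameters from the assumed drift model."""
    mean, std = ou_drift_moments(theta, gamma, mu0, sigma0)
    return mean + std * rng.standard_normal(theta.shape)


rng = np.random.default_rng(0)
mu0, sigma0 = 0.0, 0.1       # assumed initialisation distribution N(mu0, sigma0^2)
theta = np.full(5, 2.0)      # current parameters, far from the initialisation

# gamma = 1: no drift, parameters are left unchanged.
no_drift = ou_drift_step(theta, 1.0, mu0, sigma0, rng)

# gamma = 0: hard reset, parameters are redrawn from the initialisation distribution.
hard_reset = ou_drift_step(theta, 0.0, mu0, sigma0, rng)

# 0 < gamma < 1: soft reset, the mean moves part of the way back towards mu0.
soft_mean, soft_std = ou_drift_moments(theta, 0.5, mu0, sigma0)
```

Under this assumed form, values of the drift parameter between 0 and 1 interpolate between leaving the parameters untouched and re-initialising them, which is the sense in which the drift acts as a soft reset.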

Paper Structure

This paper contains 51 sections, 80 equations, 24 figures, 3 tables, 3 algorithms.

Figures (24)

  • Figure 1: Left: graphical model for the data-generating process in (a) the stationary case and (b) the non-stationary case with drift model $p(\theta_{t+1} | \theta_t, \gamma_t)$. Right: (c) In a stationary online learning regime, the Bayesian posterior (red dashed circles) concentrates around $\theta^*$ (red dot) in the long run. (d) In a non-stationary regime where the optimal parameters suddenly change from the current value $\theta_t^*$ to a new value $\theta^*_{t+1}$ (blue dot), online Bayesian estimation can be less data efficient and take time to recover after the change-point. (e) Using $p(\theta|\theta_t, \gamma_t)$ and estimating $\gamma_t$ makes it possible to increase the uncertainty by softly resetting the posterior towards the prior (green dashed circle), so that the updated Bayesian posterior $p_{t+1}(\theta)$ (blue dashed circle) can track $\theta_{t+1}^*$ faster.
  • Figure 2: Plasticity benchmarks. Left: performance on permuted MNIST. Center: performance on random-label MNIST (data efficient). Right: performance on random-label CIFAR-10 (memorization). The x-axis is the task id and the y-axis is the per-task training accuracy (\ref{eq:per_task_avg_accuracy}).
  • Figure 3: Different variants of Soft Resets. Left: performance on permuted MNIST. Center: performance on random-label MNIST (data efficient). Right: performance on random-label CIFAR-10 (memorization). The x-axis is the task id and the y-axis is the per-task training accuracy (\ref{eq:per_task_avg_accuracy}).
  • Figure 4: Left: the minimum encountered $\gamma_t$ for each layer on random-label MNIST and CIFAR-10. Center: the dynamics of $\gamma_t$ on the first 20 tasks on MNIST. Right: the same on CIFAR-10.
  • Figure 5: (a) The x-axis denotes the layer and the y-axis the minimum encountered $\gamma_t$ for each convolutional and fully-connected layer when trained on permuted Patches MNIST; color denotes the patch size. (b) The impact of non-stationarity on the performance of Online SGD and Hard Reset on random-label MNIST. (c) The same for Soft Resets. In (b) and (c), the x-axis denotes the number of epochs each task lasts, while the marker and line styles denote the percentage of random labels within each task: circle (solid) represents $20\%$, rectangle (dashed) $40\%$, and rhombus (dash-dotted) $60\%$. The y-axis denotes the average performance (over $3$ seeds) on the stream of $200$ tasks.
  • ...and 19 more figures