BP(λ): Online Learning via Synthetic Gradients

Joseph Pemberton, Rui Ponte Costa

TL;DR

The paper addresses the cost and delay of backpropagation through time by introducing accumulate $\mathrm{BP}(\lambda)$, an online method to learn synthetic gradients without BPTT using forward eligibility traces. It analytically shows that accumulate $\mathrm{BP}(\lambda)$ approximates online $\lambda$-SG, thereby reducing bootstrapping bias, and empirically demonstrates improved gradient alignment and learning on toy tasks, sequential MNIST, and copy-repeat tasks. The approach improves the handling of long-range temporal dependencies and offers insights into biological plausibility through online learning and the role of eligibility traces. This work provides a bias-free, online alternative for temporal supervised learning with potential applications in both AI systems and neuroscience.

Abstract

Training recurrent neural networks typically relies on backpropagation through time (BPTT). BPTT depends on forward and backward passes to be completed, rendering the network locked to these computations before loss gradients are available. Recently, Jaderberg et al. proposed synthetic gradients to alleviate the need for full BPTT. In their implementation synthetic gradients are learned through a mixture of backpropagated gradients and bootstrapped synthetic gradients, analogous to the temporal difference (TD) algorithm in Reinforcement Learning (RL). However, as in TD learning, heavy use of bootstrapping can result in bias which leads to poor synthetic gradient estimates. Inspired by the accumulate $\mathrm{TD}(\lambda)$ in RL, we propose a fully online method for learning synthetic gradients which avoids the use of BPTT altogether: accumulate $\mathrm{BP}(\lambda)$. As in accumulate $\mathrm{TD}(\lambda)$, we show analytically that accumulate $\mathrm{BP}(\lambda)$ can control the level of bias by using a mixture of temporal difference errors and recursively defined eligibility traces. We next demonstrate empirically that our model outperforms the original implementation for learning synthetic gradients in a variety of tasks, and is particularly suited for capturing longer timescales. Finally, building on recent work we reflect on accumulate $\mathrm{BP}(\lambda)$ as a principle for learning in biological circuits. In summary, inspired by RL principles we introduce an algorithm capable of bias-free online learning via synthetic gradients.
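
To make the RL analogy concrete, here is a minimal sketch of accumulate $\mathrm{TD}(\lambda)$ itself, the algorithm the abstract cites as inspiration. It is not the paper's $\mathrm{BP}(\lambda)$ update, whose exact equations are not given in this summary. The sketch assumes a linear value function, and the function name, argument layout, and hyperparameter defaults are all illustrative. It shows the two ingredients the abstract names: a bootstrapped temporal difference error and a recursively defined eligibility trace that decays by $\gamma\lambda$ at every step.

```python
def accumulate_td_lambda(episodes, n_features, alpha=0.1, gamma=1.0, lam=0.9):
    """Accumulate TD(lambda) for a linear value function V(s) = theta . phi(s).

    episodes: list of episodes, each a list of (phi, reward, phi_next)
              transitions; phi_next is None on the terminal transition.
    All names and defaults are illustrative assumptions, not from the paper.
    """
    theta = [0.0] * n_features
    for episode in episodes:
        trace = [0.0] * n_features  # eligibility trace, reset each episode
        for phi, reward, phi_next in episode:
            v = sum(w * p for w, p in zip(theta, phi))
            v_next = 0.0
            if phi_next is not None:
                v_next = sum(w * p for w, p in zip(theta, phi_next))
            # TD error: bootstraps from the current estimate of the next state
            delta = reward + gamma * v_next - v
            # Accumulating trace: decay by gamma * lambda, then add the
            # gradient of V with respect to theta (phi for a linear model)
            trace = [gamma * lam * e + p for e, p in zip(trace, phi)]
            # Online update: every recently visited state shares the error
            theta = [w + alpha * delta * e for w, e in zip(theta, trace)]
    return theta


# Three-state chain with one-hot features and a single reward at the end.
chain = [
    ([1.0, 0.0, 0.0], 0.0, [0.0, 1.0, 0.0]),
    ([0.0, 1.0, 0.0], 0.0, [0.0, 0.0, 1.0]),
    ([0.0, 0.0, 1.0], 1.0, None),
]
values = accumulate_td_lambda([chain] * 500, n_features=3)
print(values)  # each entry approaches the true value 1.0
```

With `lam=0` the trace reduces to the current feature vector and the update relies purely on bootstrapping; with `lam=1` the trace carries credit back across the whole episode, which is the knob the abstract describes for controlling bias.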

Paper Structure (20 sections, 10 theorems, 37 equations, 7 figures, 2 tables, 1 algorithm)

Key Result

Theorem 3.1

Let $\theta_{0}$ be the initial weight vector, $\theta_{t}^{BP}$ be the weight vector at time $t$ computed by accumulate $\mathrm{BP}(\lambda)$, and $\theta_{t}^{\lambda}$ be the weight vector at time $t$ computed by the online $\lambda$-SG algorithm. Furthermore, assume that $\sum_{i=0}^{t-1} \Delta_{i}^{t}$

Figures (7)

  • Figure 1: Schematic of a recurrent neural network (RNN) which learns via synthetic gradients. (a) External input $x_t$ is provided to the RNN, which has hidden state $h_t$. Due to recurrence, this state will affect the task loss at the current timestep $L_t$ and at future timesteps $L_{>t}$ not yet seen. A distinct synthesiser network receives $h_t$ as input and estimates its future loss gradient $\hat{G}_t \approx \frac{\partial L_{>t}}{\partial h_t}$, which is provided to the RNN for learning. The synthesiser learns to mimic a target gradient $v_t$. How $v_t$ is defined and learned is the focus of this paper. (b) An illustration of the accumulate $\mathrm{BP}(\lambda)$ algorithm for learning synthetic gradients in an unrolled version of the network. Current activity $h_t$ must be correctly associated to the later task loss $L_T$. Here the parameters $\theta$ of the synthesiser are updated via a mixture of temporal difference errors $\delta$ (red) and eligibility traces $e$ (green). As in accumulate $\mathrm{TD}(\lambda)$ in RL, $\delta$ is computed online using bootstrapping whilst $e$ propagates forwards with a decay component $\lambda$ with $0 \leq \lambda \leq 1$. Together, they approximate the true loss gradient. In contrast to the original synthetic gradient algorithm by Jaderberg et al. (2017), our model does not require BPTT.
  • Figure 2: $\mathrm{BP}(\lambda)$ derives true BPTT gradients over multiple timesteps. In this toy paradigm, input is only provided at timestep 1 and the task target is only available at the end of the task at time $T=10$. (a) Alignment between synthetic gradients and true gradients for a fixed RNN model across different timesteps within the task, where the synthetic gradients are learned using (accumulate) $\mathrm{BP}(\lambda)$. Alignment is defined using the cosine similarity metric. (b) The average alignment over the last $10\%$ of epochs in (a) across all timesteps.
  • Figure 3: $\mathrm{BP}(\lambda)$ drives better RNN learning in a toy task. (a) Average cosine similarity between synthetic gradients and true gradients for fixed (left) and plastic (right) RNNs. Cosine similarity for plastic RNNs is taken over the first 5 training epochs, since this initial period is the key stage of learning (i.e. before the task is perfected). (b) Learning curves of RNNs which are updated using synthetic gradients derived by $\mathrm{BP}(\lambda)$ over different sequence lengths $T$. Results show average (with SEM) over 5 different initial conditions.
  • Figure 4: Performance of $\mathrm{BP}(\lambda)$ in sequential MNIST task. (a) Schematic of task. Rows of an MNIST image are fed sequentially as input and the model must classify the digit at the end. (b) Validation accuracy during training for $\mathrm{BP}(\lambda)$ models. (c) Validation accuracy during training for models which learn synthetic gradients (SG) with $n$-step truncated BPTT as in the original implementation (Jaderberg et al., 2017); final performance of $\mathrm{BP}(1)$ (as in (b); dotted green) is given for reference. Results show mean performance over 5 different initial conditions with shaded areas representing standard error of the mean.
  • Figure 5: Performance of $\mathrm{BP}(\lambda)$ in copy-repeat task. (a) Maximum sequence length solved for $\mathrm{BP}(\lambda)$ models. A sequence length is considered as solved if the model achieves an average of $0.15$ bits error for a given length. (b) Maximum sequence length solved for models with $n$-step synthetic gradient (SG) learning methods (Jaderberg et al., 2017); best task performance of $\mathrm{BP}(1)$ (as in (a); dotted green) is shown for reference. See also the task results table (label tab:task_results) for more details. Results show mean performance over 5 different initial conditions with shaded areas representing standard error of the mean.
  • ...and 2 more figures
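
The captions for Figures 2 and 3 measure alignment between synthetic and true gradients with cosine similarity. As a minimal sketch of that metric, assuming gradients are flattened into plain lists of floats (the function name and the zero-vector convention are illustrative choices, not taken from the paper):

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length gradient vectors.

    Returns a value in [-1, 1]: 1 means the vectors point the same way,
    0 means they are orthogonal, -1 means they point in opposite directions.
    """
    if len(u) != len(v):
        raise ValueError("vectors must have the same length")
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        # Assumed convention: a zero gradient has no direction, report 0
        return 0.0
    return dot / (norm_u * norm_v)


# Hypothetical gradients, for illustration only
true_grad = [0.5, -1.0, 2.0]
synthetic_grad = [1.0, -2.0, 4.0]  # same direction, different magnitude
print(cosine_similarity(synthetic_grad, true_grad))
```

Because the metric ignores magnitude, a synthetic gradient that is a scaled copy of the true gradient scores as perfectly aligned; only its direction is being assessed.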

Theorems & Definitions (10)

  • Theorem 3.1
  • Lemma A.1
  • Lemma A.2
  • Lemma A.3
  • Lemma A.4
  • Lemma A.5
  • Lemma A.6
  • Proposition A.7
  • Lemma B.1
  • Proposition B.2