Recurrent neural networks: vanishing and exploding gradients are not the end of the story

Nicolas Zucchet, Antonio Orvieto

TL;DR

As the memory of a network increases, changes in its parameters produce increasingly large output variations, making gradient-based learning highly sensitive even without exploding gradients.

Abstract

Recurrent neural networks (RNNs) notoriously struggle to learn long-term memories, primarily due to vanishing and exploding gradients. The recent success of state-space models (SSMs), a subclass of RNNs, in overcoming such difficulties challenges our theoretical understanding. In this paper, we delve into the optimization challenges of RNNs and discover that, as the memory of a network increases, changes in its parameters result in increasingly large output variations, making gradient-based learning highly sensitive, even without exploding gradients. Our analysis further reveals the importance of the element-wise recurrence design pattern, combined with careful parametrizations, in mitigating this effect. This feature is present in SSMs, as well as in other architectures, such as LSTMs. Overall, our insights provide a new explanation for some of the difficulties in gradient-based learning of RNNs and for why some architectures perform better than others.

Paper Structure

This paper contains 61 sections, 2 theorems, 98 equations, 15 figures, and 3 tables.

Key Result

Lemma 1

For $\alpha, \beta \in \mathbb{C}$ satisfying $|\alpha| < 1$ and $|\beta| < 1$, and $(u_n)_{n\in \mathbb{Z}}$ a bounded sequence satisfying $u_{-n} = u_n$, we have

Figures (15)

  • Figure 1: Optimization of recurrent neural networks gets harder as their memory increases. A. Evolution of the second moment of $\mathrm{d}_\lambda h_t$ as a function of the recurrent parameter $\lambda$ and of the auto-correlation decay rate $\rho$ of the input $x$, when $h_{t+1} = \lambda h_t + x_t$. As the memory of the network increases ($\lambda \rightarrow 1$), $h_t$ becomes more sensitive to changes in $\lambda$, particularly as the elements in the input sequence are more correlated ($\rho \rightarrow 1$). The explosion of $\mathrm{d}_\lambda h_t$ is faster than that of $h_t$, as highlighted by the grey line obtained for $\rho = 1$. See Section \ref{subsec:sig_prop_diag} for more detail. B, C. Illustration of the phenomenon on the toy one-dimensional teacher-student task of Section \ref{subsec:lsi_1D}, in which the teacher is parametrized by a real number $\lambda^*$ and the student by a complex number $\lambda$. In B, $\lambda$ varies along the real axis; in C, it varies on the circle of radius $\lambda^*$, parametrized by the angle $\theta$. The loss becomes sharper as information is kept longer in memory, making gradient-based optimization nearly impossible.
  • Figure 2: Illustration of the effects of normalization and reparametrization. Together, they can effectively control the magnitude of A. $\mathbb{E}[h_t^2]$ and B. $\mathbb{E}[(\mathrm{d}_\lambda h_t)^2]$ over all $\lambda$ values when the input auto-correlation satisfies $R_x(\Delta) = \rho^{|\Delta|}$ with $\rho = 0$, but do not manage to do so for other types of distributions ($\rho \neq 0$). Here, we use $\gamma(\lambda) = \sqrt{1 - \lambda^2}$, decouple it from $\lambda$ when differentiating, and take $\lambda = \exp(-\exp(\nu))$, as in Orvieto et al. (2023). The grey line indicates the value the two quantities take without any normalization or reparametrization, when $\rho = 1$.
  • Figure 3: LRUs are better at replicating a teacher's behavior than linear RNNs. A. As the teacher encodes longer dependencies ($\nu_0 \rightarrow 1$), the linear RNN struggles to reproduce it, but not the LRU. B. An ablation study ($\nu_0 = 0.99$) reveals that this gap mainly comes from having a close-to-diagonal recurrent connectivity matrix. See Section \ref{subsec:diagonal_connectivity} for more detail.
  • Figure 4: Differences in learning abilities between fully connected and complex diagonal linear RNNs are due to a better structure of the loss landscape. A, B. Hessian of the loss at optimality, its 10 eigenvectors with the greatest eigenvalues, and its eigenspectrum for a fully connected RNN (A) and a complex diagonal one (B). The spectra are almost the same. However, the top eigenvectors are concentrated on a few coordinates for the complex diagonal RNN but not for the fully connected one. C, D. This structure makes it possible for Adam to deal efficiently with the extra sensitivity, as shown by the effective learning rates that it uses at the end of learning. For the fully connected RNN (C), Adam uses small learning rates to compensate for the sensitivity, whereas it can use larger ones for the complex diagonal RNN (D) without hindering training stability. The horizontal grey line shows the learning rate used, which is here $10^{-3}$.
  • Figure 5: Signal propagation in deep recurrent networks at initialization is consistent with our theory. A. $\mathbb{E}[h_t^2]$ after the first and the fourth layer, as a function of the exponential decay parameter $\nu_0$, for complex-valued diagonal RNN (cRNN), LRU, and GRU recurrent layers. The input normalization present in the LRU and in the GRU effectively keeps neural activity constant across $\nu_0$ values. B. Comparison of the evolution of the loss gradient $\mathbb{E}[(\mathrm{d}_\theta L)^2]$ for the different recurrent layers and specific groups of parameters. For the complex diagonal RNN, the gradients of all parameters explode, in particular those of the recurrent parameters, whereas only those of the angle of $\lambda$ explode for the LRU, consistent with the theory. Error signal propagation in GRUs is under control: the magnitude of the gradients is independent of $\nu_0$. The GRU-specific parameters exhibit smaller gradients than the feedforward (ff) ones. C. Layer normalization keeps the overall gradient magnitude under control in cRNNs. Batch normalization yields similar results.
  • ...and 10 more figures
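
The quantities plotted in Figures 1A and 2 can be explored with a small numerical sketch. It is not the authors' code. It assumes a real $\lambda \in (0, 1)$, a stationary unit-variance input with $R_x(\Delta) = \rho^{|\Delta|}$, and a truncated unrolling of $h_{t+1} = \lambda h_t + x_t$. The function names (`second_moment`, `moments`, `normalized_moments`) and the `horizon` parameter are hypothetical, introduced only for illustration.

```python
import math


def second_moment(weights, rho):
    """E[(sum_n w_n x_{t-n})^2] for a stationary, unit-variance input
    whose auto-correlation is R_x(d) = rho ** |d| (an assumption here)."""
    total, cross, prev = 0.0, 0.0, 0.0
    for w in weights:
        # cross = sum over earlier lags m < n of w_m * rho ** (n - m)
        cross = rho * (cross + prev)
        total += w * (w + 2.0 * cross)
        prev = w
    return total


def moments(lam, rho, horizon=5000):
    """Return (E[h_t^2], E[(d_lambda h_t)^2]) for h_{t+1} = lam * h_t + x_t.

    Unrolling gives h_t = sum_n lam^n x_{t-1-n}, hence
    d_lambda h_t = sum_n n * lam^(n-1) * x_{t-1-n}.
    Both sums are truncated after `horizon` terms (assumed large enough)."""
    h_w, dh_w, power = [], [], 1.0  # power holds lam ** n
    for n in range(horizon):
        h_w.append(power)
        dh_w.append((n + 1) * power)  # weight (n + 1) * lam ** n
        power *= lam
    return second_moment(h_w, rho), second_moment(dh_w, rho)


def normalized_moments(lam, rho, horizon=5000):
    """Same two moments with the input scaled by gamma = sqrt(1 - lam^2),
    held constant when differentiating, and with lam = exp(-exp(nu)),
    so that d_nu h_t = (lam * log(lam)) * d_lambda h_t."""
    h2, dh2 = moments(lam, rho, horizon)
    gamma2 = 1.0 - lam * lam
    dlam_dnu = lam * math.log(lam)
    return gamma2 * h2, gamma2 * dlam_dnu ** 2 * dh2


if __name__ == "__main__":
    for lam in (0.5, 0.9, 0.99):
        for rho in (0.0, 1.0):
            print(lam, rho, moments(lam, rho), normalized_moments(lam, rho))
```

Under these assumptions, the unnormalized second moment of $\mathrm{d}_\lambda h_t$ grows much faster than that of $h_t$ as $\lambda \rightarrow 1$. The normalized, reparametrized version stays bounded for $\rho = 0$ but still grows for $\rho = 1$, which matches the qualitative behavior described in the two captions.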

Theorems & Definitions (4)

  • Lemma 1
  • proof
  • Lemma 2
  • proof