Dynamical Isometry and a Mean Field Theory of LSTMs and GRUs

Dar Gilboa, Bo Chang, Minmin Chen, Greg Yang, Samuel S. Schoenholz, Ed H. Chi, Jeffrey Pennington

TL;DR

This work develops a mean-field framework to analyze signal propagation and gradient dynamics in LSTMs and GRUs, enabling precise initialization strategies that realize dynamical isometry. By deriving forward propagation time scales and the spectrum of the state-to-state Jacobian, the authors propose a critical initialization that stabilizes training on long sequences and can even improve generalization. The authors validate the theory through multiple long-sequence tasks (Padded MNIST, unrolled MNIST, CIFAR-10, repeated-pixel variants), showing substantial gains over standard initializations. The findings illuminate how initialization hyperparameters control information flow and offer practical guidance for initializing gated recurrent cells and potentially simplifying architectures while preserving trainability.

Abstract

Training recurrent neural networks (RNNs) on long sequence tasks is plagued with difficulties arising from the exponential explosion or vanishing of signals as they propagate forward or backward through the network. Many techniques have been proposed to ameliorate these issues, including various algorithmic and architectural modifications. Two of the most successful RNN architectures, the LSTM and the GRU, do exhibit modest improvements over vanilla RNN cells, but they still suffer from instabilities when trained on very long sequences. In this work, we develop a mean field theory of signal propagation in LSTMs and GRUs that enables us to calculate the time scales for signal propagation as well as the spectral properties of the state-to-state Jacobians. By optimizing these quantities in terms of the initialization hyperparameters, we derive a novel initialization scheme that eliminates or reduces training instabilities. We demonstrate the efficacy of our initialization scheme on multiple sequence tasks, on which it enables successful training while a standard initialization either fails completely or is orders of magnitude slower. We also observe a beneficial effect on generalization performance using this new initialization.
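
The exponential explosion or vanishing described above can be reproduced in a toy setting. The sketch below is a minimal illustration under stated assumptions, not the paper's LSTM/GRU analysis. It propagates a random vector through a linear recurrence $h_{t+1} = W_t h_t$ and draws a fresh Gaussian matrix at every step (untied weights), with entries i.i.d. from $\mathcal{N}(0, \sigma_w^2/N)$. The function name `propagate_norm_ratio` and all parameter values are hypothetical. Each step rescales the expected squared norm by $\sigma_w^2$, so the norm after $T$ steps behaves roughly like $\sigma_w^T$.

```python
import numpy as np


def propagate_norm_ratio(sigma_w: float, n: int = 200, steps: int = 50, seed: int = 0) -> float:
    """Return ||h_T|| / ||h_0|| for the toy linear recurrence h_{t+1} = W_t h_t.

    Each W_t is a fresh (untied) n x n matrix with i.i.d. N(0, sigma_w**2 / n)
    entries, so every step rescales the expected squared norm by sigma_w**2.
    This is a hypothetical toy model, not an LSTM or GRU.
    """
    rng = np.random.default_rng(seed)
    h0 = rng.normal(size=n)
    h = h0.copy()
    for _ in range(steps):
        # Draw a new weight matrix each step (untied weights).
        w = rng.normal(0.0, sigma_w / np.sqrt(n), size=(n, n))
        h = w @ h
    return float(np.linalg.norm(h) / np.linalg.norm(h0))


for sigma_w in (0.5, 1.0, 2.0):
    print(f"sigma_w = {sigma_w}: ||h_T|| / ||h_0|| = {propagate_norm_ratio(sigma_w):.3e}")
```

With $T = 50$, $\sigma_w = 0.5$ shrinks the norm by roughly $0.5^{50} \approx 10^{-15}$. $\sigma_w = 2$ inflates it by a comparable factor, while $\sigma_w = 1$ keeps it of order one. The paper's critical initialization targets an analogous balance for the much more complex gated LSTM and GRU updates.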

Paper Structure

This paper contains 34 sections, 3 theorems, 47 equations, 7 figures, 3 tables, and 1 algorithm.

Key Result

Lemma 1

For a recurrent neural network defined by (eq_s_system), the mean squared singular value of the state-to-state Jacobian defined in (eq_m_1_def) and $\chi_{C_s}$, which determines the time scale of forward signal propagation (given by (eq_chi_def)), are related by

Figures (7)

  • Figure 1: Critical initialization improves trainability of recurrent networks. Test accuracy for a peephole LSTM trained to classify sequences of MNIST digits after 8000 iterations. As the sequence length increases, the network can no longer be trained with a standard initialization but remains trainable with the critical initialization.
  • Figure 2: Training accuracy on the padded MNIST classification task described in Section sec_padded at different sequence lengths $T$ and hyperparameter values $\Theta_0 + \alpha \Theta_1$ for networks with untied weights, with different values of $\Theta_0, \Theta_1$ chosen for each architecture. The dark and light green curves are $3\xi$ and $6\xi$ respectively, where $\xi$ is the theoretical signal propagation time scale in eqn. (eq_xi). This time scale predicts the transition between the regions of high and low accuracy across the different architectures and directions in hyperparameter space.
  • Figure 3: Squared singular values of the state-to-state Jacobian in eqn. (eq_J_def) for two choices of hyperparameter settings $\Theta$. The red lines denote the empirical mean and standard deviation, while the dotted lines denote the theoretical prediction based on the calculation described in Section sec_jac. Note the dramatic difference in the spectrum caused by choosing an initialization that approximately satisfies the dynamical isometry conditions.
  • Figure 4: Training accuracy for unrolled, concatenated MNIST digits (top) and unrolled MNIST digits with replicated pixels (bottom) for different sequence lengths. Left: For shorter sequences the standard and critical initialization perform equivalently. Middle: As the sequence length is increased, training with a critical initialization is faster by orders of magnitude. Right: For very long sequence lengths, training with a standard initialization fails completely (and is unstable from initialization in the lower right panel).
  • Figure 5: Top: Dynamics of the correlations (eq_Mc) for the GRU with 3 different values of $\mu_f$ as a function of time. The dashed line is the prediction from the mean field calculation, while the red curves are from a simulation of the network with i.i.d. Gaussian inputs. Left: Network with untied weights. Right: Network with tied weights. Bottom: The predicted fixed point of (eq_Mc) as a function of $\mu_f$. Left: Network with untied weights. Right: Network with tied weights.
  • ...and 2 more figures
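
Figure 2 compares accuracy against multiples of the signal propagation time scale $\xi$, and Figure 5 tracks correlations converging to a fixed point. The sketch below rests on an assumption: it uses the convention common in mean field analyses of deep networks. Near a stable fixed point $c^*$ of a correlation map $c_{t+1} = f(c_t)$, deviations shrink by a factor $\chi = f'(c^*)$ per step, so they decay like $\chi^t = e^{-t/\xi}$ with $\xi = -1/\ln \chi$. Whether this matches eqn. (eq_xi) exactly is not confirmed here. `toy_map` is a hypothetical stand-in, not the LSTM or GRU correlation map derived in the paper, and the helper names are illustrative.

```python
import math
from typing import Callable


def find_fixed_point(f: Callable[[float], float], c0: float,
                     tol: float = 1e-12, max_iter: int = 10_000) -> float:
    """Iterate c_{t+1} = f(c_t) until successive values agree to within tol."""
    c = c0
    for _ in range(max_iter):
        c_next = f(c)
        if abs(c_next - c) < tol:
            return c_next
        c = c_next
    raise RuntimeError("fixed-point iteration did not converge")


def slope_at(f: Callable[[float], float], c: float, eps: float = 1e-6) -> float:
    """Central finite-difference estimate of chi = f'(c)."""
    return (f(c + eps) - f(c - eps)) / (2 * eps)


def time_scale(chi: float) -> float:
    """xi = -1 / ln(chi), assuming deviations decay like chi**t = exp(-t / xi)."""
    if not 0.0 < chi < 1.0:
        raise ValueError("a finite decay time scale requires 0 < chi < 1")
    return -1.0 / math.log(chi)


def toy_map(c: float) -> float:
    # Hypothetical correlation map, NOT the LSTM/GRU map from the paper.
    return 0.5 + 0.4 * c * c


c_star = find_fixed_point(toy_map, c0=0.0)
chi = slope_at(toy_map, c_star)
xi = time_scale(chi)
print(f"c* = {c_star:.4f}, chi = {chi:.4f}, xi = {xi:.3f}, "
      f"3xi = {3 * xi:.2f}, 6xi = {6 * xi:.2f}")
```

For this toy map the iteration converges to $c^* = (1 - \sqrt{0.2})/0.8 \approx 0.691$, with $\chi = 0.8\,c^* \approx 0.553$ and $\xi \approx 1.69$ steps. As $\chi \to 1$, $\xi$ diverges. That is the regime in which signals can propagate over long sequences.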

Theorems & Definitions (7)

  • Lemma 1
  • proof
  • Proof of Lemma 1
  • Lemma 2
  • Proof of Lemma 2
  • Lemma 3
  • Proof of Lemma 3
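
Lemma 1 and Figure 3 both concern the squared singular values of the state-to-state Jacobian. The sketch below only illustrates the quantities involved, using hypothetical random matrices in place of an actual LSTM or GRU Jacobian. It computes two things:

- the mean squared singular value $m_1 = \frac{1}{N}\operatorname{tr}(J J^\top)$, assumed here to be the normalization meant by (eq_m_1_def);
- the spread of the squared singular values.

Dynamical isometry asks for that spectrum to concentrate near 1. An orthogonal matrix satisfies this exactly, whereas a Gaussian matrix with the same $m_1 \approx 1$ has a broad spectrum. The helper names are illustrative.

```python
import numpy as np


def squared_singular_values(jac: np.ndarray) -> np.ndarray:
    """Squared singular values of J, i.e. the eigenvalues of J J^T."""
    return np.linalg.svd(jac, compute_uv=False) ** 2


def mean_squared_singular_value(jac: np.ndarray) -> float:
    """m_1 = tr(J J^T) / N, the first moment of the spectrum of J J^T."""
    return float(np.trace(jac @ jac.T) / jac.shape[0])


rng = np.random.default_rng(0)
n = 300

# Hypothetical stand-ins for a state-to-state Jacobian (not an LSTM/GRU Jacobian).
gaussian = rng.normal(0.0, 1.0 / np.sqrt(n), size=(n, n))  # E[m_1] = 1, broad spectrum
orthogonal, _ = np.linalg.qr(rng.normal(size=(n, n)))      # every singular value equals 1

for name, jac in [("gaussian", gaussian), ("orthogonal", orthogonal)]:
    s2 = squared_singular_values(jac)
    print(f"{name:>10}: m_1 = {mean_squared_singular_value(jac):.3f}, "
          f"std of squared singular values = {s2.std():.3f}")
```

Both matrices have $m_1 \approx 1$. However, the Gaussian spectrum spreads from near 0 to about 4, following the Marchenko–Pastur law, while the orthogonal spectrum is a single point at 1. In a much simpler setting, this mirrors the contrast in spectra that Figure 3 reports.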