
Natural Neural Networks

Guillaume Desjardins, Karen Simonyan, Razvan Pascanu, Koray Kavukcuoglu

TL;DR

The paper tackles optimization difficulties in deep neural networks that arise from ill-conditioned parameter interactions under the KL-divergence geometry. It introduces Natural Neural Networks, which use a whitening-based per-layer reparametrization, together with Projected Natural Gradient Descent (PRONG), which amortizes the cost of whitening and links the approach to Mirror Descent. Through a theoretical derivation of the layerwise Fisher matrix and experiments ranging from auto-encoders to ImageNet, it demonstrates improved Fisher conditioning, faster convergence, and scalability. The work points to practical training benefits, potential for model compression, and connections to online optimization.
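To make the conditioning argument concrete, consider a deliberately simple special case (our assumption, not a setting from the paper): a bias-free linear layer with unit-variance Gaussian output noise. Its Fisher with respect to each output row of the weights is the input second-moment matrix $E[x x^T]$. The NumPy sketch below shows that PCA whitening of the input, $U = \mathrm{diag}(\lambda + \epsilon)^{-1/2} E^T$, turns a badly conditioned Fisher into the identity. The helper names `whitening_transform` and `linear_gaussian_fisher` are hypothetical.

```python
import numpy as np

def whitening_transform(x, eps=0.0):
    """PCA whitening: returns (c, U) so that U @ (x_n - c) has ~identity covariance.

    U = diag(lam + eps)^(-1/2) E^T, where (lam, E) are the eigenvalues and
    eigenvectors of the sample covariance of x (rows are examples).
    """
    c = x.mean(axis=0)
    cov = np.cov(x - c, rowvar=False, bias=True)
    lam, E = np.linalg.eigh(cov)
    U = np.diag(1.0 / np.sqrt(lam + eps)) @ E.T
    return c, U

def linear_gaussian_fisher(x):
    """Fisher w.r.t. one weight row of a bias-free linear layer with unit
    Gaussian output noise: the input second-moment matrix E[x x^T]."""
    return x.T @ x / x.shape[0]

rng = np.random.default_rng(0)
# Inputs whose per-dimension standard deviations span two orders of magnitude.
x = rng.normal(size=(5000, 3)) * np.array([10.0, 1.0, 0.1])

c, U = whitening_transform(x)
z = (x - c) @ U.T  # whitened inputs

kappa_raw = np.linalg.cond(linear_gaussian_fisher(x))
kappa_white = np.linalg.cond(linear_gaussian_fisher(z))
print(f"condition number of the Fisher, raw inputs:      {kappa_raw:.1e}")
print(f"condition number of the Fisher, whitened inputs: {kappa_white:.3f}")
```

In a real network the layerwise Fisher also depends on the back-propagated gradients, so input whitening addresses only part of the ill-conditioning; Figure 2(c) tracks the effect measured during training.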

Abstract

We introduce Natural Neural Networks, a novel family of algorithms that speed up convergence by adapting their internal representation during training to improve conditioning of the Fisher matrix. In particular, we show a specific example that employs a simple and efficient reparametrization of the neural network weights by implicitly whitening the representation obtained at each layer, while preserving the feed-forward computation of the network. Such networks can be trained efficiently via the proposed Projected Natural Gradient Descent algorithm (PRONG), which amortizes the cost of these reparametrizations over many parameter updates and is closely related to the Mirror Descent online learning algorithm. We highlight the benefits of our method on both unsupervised and supervised learning tasks, and showcase its scalability by training on the large-scale ImageNet Challenge dataset.
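The key property of the reparametrization is that installing new whitening statistics leaves the network's output unchanged. A minimal NumPy sketch, assuming a layer of the form $h_i = f(V_i U_{i-1}(h_{i-1} - c_{i-1}) + d_i)$ with PCA whitening. Writing the canonical parameters as $W_i = V_i U_{i-1}$ and $b_i = d_i - W_i c_{i-1}$, new statistics $(U', c')$ are absorbed by setting $V_i' = W_i U'^{-1}$ and $d_i' = b_i + W_i c'$. All function names are hypothetical.

```python
import numpy as np

def whitening_transform(h, eps=1e-3):
    """PCA whitening of a batch h (rows are examples): returns (c, U)."""
    c = h.mean(axis=0)
    cov = np.cov(h - c, rowvar=False, bias=True)
    lam, E = np.linalg.eigh(cov)
    U = np.diag(1.0 / np.sqrt(lam + eps)) @ E.T
    return c, U

def natural_layer(h, V, d, U, c, f=np.tanh):
    """Batched h_next = f(V U (h - c) + d)."""
    return f((h - c) @ U.T @ V.T + d)

def reparametrize(V, d, U_old, c_old, U_new, c_new):
    """Install new whitening statistics without changing the layer's output."""
    W = V @ U_old                  # canonical weights: W = V U
    b = d - W @ c_old              # canonical bias:    b = d - W c
    V_new = W @ np.linalg.inv(U_new)
    d_new = b + W @ c_new
    return V_new, d_new

rng = np.random.default_rng(0)
# Correlated, non-centred layer inputs (batch of 1000, width 4).
h = rng.normal(size=(1000, 4)) @ rng.normal(size=(4, 4)) + 3.0
V, d = 0.3 * rng.normal(size=(3, 4)), rng.normal(size=3)
U, c = np.eye(4), np.zeros(4)            # start from the identity "whitening"

out_before = natural_layer(h, V, d, U, c)
c_new, U_new = whitening_transform(h)
V, d = reparametrize(V, d, U, c, U_new, c_new)
out_after = natural_layer(h, V, d, U_new, c_new)

z = (h - c_new) @ U_new.T                # whitened layer inputs
cov_z = np.cov(z, rowvar=False, bias=True)
print(np.max(np.abs(out_before - out_after)))  # round-off level: output unchanged
print(np.round(cov_z, 3))  # diagonal, entries lam / (lam + eps) <= 1
```

Because the forward computation is preserved exactly, the reparametrization can be applied at any point during training without disturbing what the network has learned; only the coordinates of subsequent gradient steps change.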

Paper Structure

This paper contains 16 sections, 10 equations, 4 figures, and 1 algorithm.

Figures (4)

  • Figure 1: (a) A 2-layer natural neural network. (b) Illustration of the projections involved in PRONG.
  • Figure 2: Fisher matrix for a small MLP (a) before and (b) after the first reparametrization. Best viewed in colour. (c) Condition number of the FIM during training, relative to the initial conditioning.
  • Figure 3: Optimizing a deep auto-encoder on MNIST. (a) Impact of the eigenvalue regularization term $\epsilon$. (b) Impact of the amortization period $T$, showing that initialization with the whitening reparametrization is important for faster learning and a lower error rate. (c) Training error vs. number of updates. (d) Training error vs. CPU time. Plots (c) and (d) show that PRONG achieves a lower error rate both per update and in wall-clock time.
  • Figure 4: Classification error on CIFAR-10 (a-b) and ImageNet (c-d). On CIFAR-10, PRONG achieves better test error and converges faster. On ImageNet, PRONG$^+$ achieves comparable validation error while maintaining a faster convergence rate.
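Figure 3 varies two hyperparameters of PRONG: the eigenvalue regularizer $\epsilon$ and the amortization period $T$. The NumPy sketch below shows where each one enters, in a deliberately small setting that is our assumption rather than the paper's: a single linear layer trained by full-batch gradient descent on a squared loss. Every $T$ updates, whitening statistics are re-estimated from a random subsample of `n_stats` examples and absorbed by a function-preserving reparametrization. As a baseline, the sketch also runs plain gradient descent on the raw parameters. All function and parameter names are hypothetical.

```python
import numpy as np

def whitening_transform(x, eps):
    """PCA whitening statistics (c, U) with eigenvalue regularizer eps."""
    c = x.mean(axis=0)
    cov = np.cov(x - c, rowvar=False, bias=True)
    lam, E = np.linalg.eigh(cov)
    return c, np.diag(1.0 / np.sqrt(lam + eps)) @ E.T

def prong_style_fit(x, y, lr, n_steps, T, eps, n_stats, seed=0):
    """Gradient descent on whitened parameters (V, d) of y ~ V U (x - c) + d,
    re-estimating (c, U) every T updates and reparametrizing so that the
    model's predictions are unchanged at each re-estimation."""
    rng = np.random.default_rng(seed)
    V = np.zeros((y.shape[1], x.shape[1]))
    d = np.zeros(y.shape[1])
    c, U = np.zeros(x.shape[1]), np.eye(x.shape[1])
    losses = []
    for t in range(n_steps):
        if t % T == 0:
            idx = rng.choice(len(x), size=n_stats, replace=False)
            c_new, U_new = whitening_transform(x[idx], eps)
            W, b = V @ U, d - V @ U @ c               # canonical parameters
            V, d = W @ np.linalg.inv(U_new), b + W @ c_new
            c, U = c_new, U_new
        z = (x - c) @ U.T                             # whitened inputs
        r = z @ V.T + d - y                           # residuals
        losses.append(0.5 * np.mean(np.sum(r ** 2, axis=1)))
        V = V - lr * r.T @ z / len(x)
        d = d - lr * r.mean(axis=0)
    return losses

def plain_gd_fit(x, y, lr, n_steps):
    """Baseline: gradient descent directly on canonical parameters (W, b)."""
    W = np.zeros((y.shape[1], x.shape[1]))
    b = np.zeros(y.shape[1])
    losses = []
    for _ in range(n_steps):
        r = x @ W.T + b - y
        losses.append(0.5 * np.mean(np.sum(r ** 2, axis=1)))
        W = W - lr * r.T @ x / len(x)
        b = b - lr * r.mean(axis=0)
    return losses

rng = np.random.default_rng(1)
x = rng.normal(size=(2000, 3)) * np.array([10.0, 1.0, 0.1])   # ill-conditioned inputs
W_true = rng.normal(size=(2, 3))
y = x @ W_true.T + np.array([0.5, -1.0])

prong_losses = prong_style_fit(x, y, lr=0.5, n_steps=200, T=20, eps=1e-3, n_stats=500)
# lr=0.01 is about half the stability limit 2 / (largest input variance ~ 100).
plain_losses = plain_gd_fit(x, y, lr=0.01, n_steps=200)
print(f"whitened: {prong_losses[-1]:.2e}   plain: {plain_losses[-1]:.2e}")
```

Because each re-estimation is followed by a reparametrization that leaves the predictions unchanged, the loss does not change at those steps (up to round-off); whitening only changes the coordinates in which the next $T$ gradient steps are taken. In a deep network, hidden-layer statistics drift as lower layers train, which is what motivates re-estimating them periodically rather than once.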