Revisiting Scalable Hessian Diagonal Approximations for Applications in Reinforcement Learning

Mohamed Elsayed, Homayoon Farrahi, Felix Dangel, A. Rupam Mahmood

TL;DR

The paper tackles the challenge of leveraging second-order information by revisiting deterministic Hessian diagonal approximations. It introduces HesScale, a refinement of the BL89 diagonal scheme that computes exact Hessian diagonals for the last layer and propagates diagonal estimates to earlier layers at a cost linear in the number of parameters, along with a Gauss-Newton variant, HesScaleGN, that simplifies the computation further. Across supervised and reinforcement learning tasks involving small networks, HesScale yields higher-quality diagonal approximations than the alternatives, and the optimizers built on it (AdaHesScale and AdaHesScaleGN) optimize faster; a step-size scaling mechanism based on the HesScale Hessian estimate further improves robustness and stability in RL. The findings suggest that scalable second-order methods powered by HesScale can improve efficiency and reliability in RL and may extend to larger models in the future.
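The TL;DR mentions AdaHesScale, the optimizer that consumes HesScale's estimates. Below is a minimal sketch of a single update, assuming (this summary does not confirm it) that AdaHesScale follows the AdaHessian pattern: Adam's second-moment estimate tracks the squared Hessian-diagonal estimate instead of the squared gradient. The function `adahesscale_step`, its hyperparameter defaults, and the list-based parameter layout are hypothetical; `hess_diag` stands in for the output of HesScale or HesScaleGN.

```python
import math
from typing import List, Tuple


def adahesscale_step(
    params: List[float],
    grads: List[float],
    hess_diag: List[float],
    m: List[float],
    v: List[float],
    t: int,
    lr: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
) -> Tuple[List[float], List[float], List[float]]:
    """One hypothetical AdaHesScale-style step (assumption: Adam with the
    squared Hessian-diagonal estimate in the second moment, as in AdaHessian).
    `t` is the 1-based step count used for bias correction."""
    new_params, new_m, new_v = [], [], []
    for p, g, h, m_i, v_i in zip(params, grads, hess_diag, m, v):
        # First moment: exponential moving average of gradients, as in Adam.
        m_i = beta1 * m_i + (1.0 - beta1) * g
        # Second moment: EMA of squared curvature estimates instead of g * g.
        v_i = beta2 * v_i + (1.0 - beta2) * h * h
        # Bias-correct both zero-initialized averages.
        m_hat = m_i / (1.0 - beta1 ** t)
        v_hat = v_i / (1.0 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# Usage: f(x) = 2 x^2 has gradient 4x and constant Hessian 4. On the first
# step the bias-corrected ratio is g / |h| = x, so lr = 1.0 takes a Newton step.
x0 = 3.0
params, m, v = adahesscale_step([x0], [4.0 * x0], [4.0], [0.0], [0.0], t=1, lr=1.0)
print(params)  # very close to [0.0]
```

Squaring `hess_diag` keeps the denominator non-negative even where the curvature estimate is negative, so each step always moves against the momentum-averaged gradient.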

Abstract

Second-order information is valuable for many applications but challenging to compute. Several works focus on computing or approximating Hessian diagonals, but even this simplification introduces significant additional costs compared to computing a gradient. In the absence of efficient exact computation schemes for Hessian diagonals, we revisit an early approximation scheme proposed by Becker and LeCun (1989, BL89), which has a cost similar to gradients and appears to have been overlooked by the community. We introduce HesScale, an improvement over BL89, which adds negligible extra computation. On small networks, we find that this improved approximation is of higher quality than all alternatives, even those with theoretical guarantees, such as unbiasedness, while being much cheaper to compute. We use this insight in reinforcement learning problems where small networks are used and demonstrate HesScale in second-order optimization and scaling the step-size parameter. In our experiments, HesScale optimizes faster than existing methods and improves stability through step-size scaling. These findings are promising for scaling second-order methods in larger models in the future.
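The abstract contrasts HesScale with alternatives that carry guarantees such as unbiasedness. One widely used member of that class is Hutchinson's estimator (the approach AdaHessian uses), based on the identity $\mathbb{E}[z \odot Hz] = \operatorname{diag}(H)$ for Rademacher vectors $z$. The sketch is illustrative only: an explicit 3-by-3 matrix stands in for a network Hessian, and the hypothetical helper `matvec` stands in for a Hessian-vector product, which in a real network costs a small multiple of a gradient evaluation per sample.

```python
import random
from typing import Callable, List

Vector = List[float]


def matvec(H: List[Vector], z: Vector) -> Vector:
    """Dense matrix-vector product, standing in for a Hessian-vector product."""
    return [sum(h_ij * z_j for h_ij, z_j in zip(row, z)) for row in H]


def hutchinson_diag(hvp: Callable[[Vector], Vector], n: int,
                    num_samples: int, seed: int = 0) -> Vector:
    """Unbiased estimate of diag(H): average z * (H z) over Rademacher z."""
    rng = random.Random(seed)
    est = [0.0] * n
    for _ in range(num_samples):
        z = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        hz = hvp(z)  # one Hessian-vector product per sample
        for i in range(n):
            # z_i * (Hz)_i = H_ii + sum_{j != i} H_ij z_i z_j; the sum has mean zero.
            est[i] += z[i] * hz[i] / num_samples
    return est


H = [[2.0, 0.5, -1.0],
     [0.5, 3.0, 0.2],
     [-1.0, 0.2, 1.5]]
approx = hutchinson_diag(lambda z: matvec(H, z), n=3, num_samples=2000)
print(approx)  # near [2.0, 3.0, 1.5]; the noise comes from the off-diagonal entries
```

Every sample needs its own Hessian-vector product, and the error only shrinks like one over the square root of the sample count, so accurate estimates become expensive; this is the cost gap relative to a single gradient that the abstract points to.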

Paper Structure (28 sections, 1 theorem, 48 equations, 13 figures, 1 table, 3 algorithms)

Key Result

Theorem 6.1

HesScale Computation with CNNs. Under the assumption that second-order off-diagonal terms are zero in all layers of a neural network except the last one, the second derivatives can be propagated with linear complexity in the number of network parameters and in the network's output dimension using the following …
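The theorem says that if off-diagonal second-order terms are dropped in every layer except the last, second derivatives can be backpropagated at linear cost. The sketch below illustrates that idea for a small fully connected tanh network with a softmax cross-entropy loss and no biases (an assumption; the theorem itself is stated for CNNs). It is our reading, not the authors' code: the last layer's full logit Hessian, $\operatorname{diag}(q) - qq^\top$, is used when stepping back from the output, where its structure keeps the cost linear in the output dimension, and only diagonals are propagated further back. Setting `gauss_newton=True` drops the $\sigma''$ terms, which is how we assume HesScaleGN simplifies the recursion. All function names are hypothetical, and `fd_second` is a finite-difference helper included only for comparison.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]


def forward(weights: List[Matrix], x: Vector) -> Tuple[List[Vector], List[Vector]]:
    """Return pre-activations a_l and layer inputs h_{l-1} (tanh hidden layers,
    linear output logits, no biases)."""
    pre, inputs = [], [x]
    for l, W in enumerate(weights):
        a = [sum(w * h for w, h in zip(row, inputs[-1])) for row in W]
        pre.append(a)
        if l < len(weights) - 1:
            inputs.append([math.tanh(v) for v in a])
    return pre, inputs


def softmax(z: Vector) -> Vector:
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def loss(weights: List[Matrix], x: Vector, y: int) -> float:
    """Softmax cross-entropy of the logits against class y."""
    return -math.log(softmax(forward(weights, x)[0][-1])[y])


def hesscale_diag(weights: List[Matrix], x: Vector, y: int,
                  gauss_newton: bool = False) -> List[Matrix]:
    """Hypothetical HesScale-style diagonal second derivatives w.r.t. each W_l."""
    pre, inputs = forward(weights, x)
    q = softmax(pre[-1])
    n_layers = len(weights)
    # Output layer: gradient q - e_y and exact diagonal q_i (1 - q_i) of the
    # logit Hessian diag(q) - q q^T.
    g_a = [qi - (1.0 if i == y else 0.0) for i, qi in enumerate(q)]
    d_a = [qi * (1.0 - qi) for qi in q]
    diags: List[Matrix] = [[] for _ in range(n_layers)]
    for l in range(n_layers - 1, -1, -1):
        W, h = weights[l], inputs[l]
        rows, cols = range(len(W)), range(len(W[0]))
        # a_l is linear in W_l, so d2L/dW_ij^2 = d2L/da_i^2 * h_j^2.
        diags[l] = [[d_a[i] * h[j] ** 2 for j in cols] for i in rows]
        if l == 0:
            break
        g_h = [sum(W[k][j] * g_a[k] for k in rows) for j in cols]
        if l == n_layers - 1:
            # Keep the last layer's off-diagonals: (W^T (diag(q) - q q^T) W)_jj,
            # computed in time linear in the output dimension.
            d_h = [sum(W[k][j] ** 2 * q[k] for k in rows)
                   - sum(W[k][j] * q[k] for k in rows) ** 2 for j in cols]
        else:
            # Earlier layers: off-diagonals assumed zero, only diagonals propagate.
            d_h = [sum(W[k][j] ** 2 * d_a[k] for k in rows) for j in cols]
        t = [math.tanh(v) for v in pre[l - 1]]
        s1 = [1.0 - ti * ti for ti in t]                # tanh'
        s2 = [-2.0 * ti * (1.0 - ti * ti) for ti in t]  # tanh''
        g_a = [gh * d1 for gh, d1 in zip(g_h, s1)]
        d_a = [d1 * d1 * dh + (0.0 if gauss_newton else d2 * gh)
               for d1, d2, dh, gh in zip(s1, s2, d_h, g_h)]
    return diags


def fd_second(weights: List[Matrix], x: Vector, y: int, l: int, i: int, j: int,
              eps: float = 1e-4) -> float:
    """Central finite-difference estimate of d2L/dW_l[i][j]^2 (for checking)."""
    def f(delta: float) -> float:
        w = [[row[:] for row in W] for W in weights]
        w[l][i][j] += delta
        return loss(w, x, y)
    return (f(eps) - 2.0 * f(0.0) + f(-eps)) / eps ** 2


weights = [
    [[0.5, -0.3], [0.8, 0.1], [-0.4, 0.9]],                  # layer 1: 2 -> 3
    [[0.3, -0.7, 0.2], [0.6, 0.4, -0.5], [-0.2, 0.1, 0.8]],  # layer 2: 3 -> 3
    [[0.7, -0.6, 0.3], [-0.1, 0.5, 0.9], [0.4, 0.2, -0.8]],  # layer 3: 3 -> 3 logits
]
x, y = [1.0, -2.0], 1
diag = hesscale_diag(weights, x, y)
print(diag[-1][0][0], fd_second(weights, x, y, 2, 0, 0))  # agree up to FD error
```

On the last layer the two numbers agree, consistent with the Figure 2 caption's statement that HesScale uses exact entries there; for the first layer, the diagonal propagation is only an approximation.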

Figures (13)

  • Figure 1: Backpropagating the exact Hessian information in a neural network. Red arrows show the direction of influence when backpropagating the Hessians of the loss w.r.t. the pre-activations, which are then used to compute the Hessians of the loss w.r.t. the weights at each layer (blue arrows). Black arrows show the direction of influence during the forward pass.
  • Figure 2: (a) The average error for each method is normalized by the average error incurred by HesScale. Each colored point represents a different initialization. The norm of the vector of Hessian diagonals $|\operatorname{diag}({\bm{H}})|$ is shown as a reference. (b) The average layer-wise error for each method is shown. HesScale(GN) has zero error at the last layer since it uses the exact entries there.
  • Figure 3: Heat maps of the Hessians of the loss w.r.t. the pre-activations. $\{\nabla^2_{{\bm{a}}_i}\mathcal{L}\}_{i=1}^{4}$ visually appear diagonally dominant. Red represents a small magnitude, and blue represents a large magnitude.
  • Figure 4: CIFAR-100 3C3D and CIFAR-100 ALL-CNN classification tasks. (top) We show the time taken by each algorithm in seconds, and (bottom) we show the learning curves in the number of epochs. The shaded area represents the standard error. BL89's curve lies below $0.35$ and is therefore not visible.
  • Figure 5: Performance of A2C (first row) and PPO (second row) with AdaHesScale, Adam, and AdaHessian on $5$ MuJoCo environments. We show the undiscounted return averaged over $10$ independent runs. The shaded area represents the standard error.
  • ...and 8 more figures
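Figure 1 describes backpropagating the Hessian of the loss w.r.t. the pre-activations and then reading off the Hessian w.r.t. the weights. For a fully connected layer ${\bm{a}}_l = {\bm{W}}_l {\bm{h}}_{l-1} + {\bm{b}}_l$ with ${\bm{h}}_{l-1} = \sigma({\bm{a}}_{l-1})$, the standard exact recursion, the resulting weight diagonal, and the diagonal approximation that BL89-style schemes build on can be written as follows (notation chosen here for illustration, not copied from the paper):

```latex
\begin{align*}
\nabla^2_{{\bm{a}}_{l-1}}\mathcal{L}
  &= {\bm{D}}_{l-1} {\bm{W}}_l^\top \big(\nabla^2_{{\bm{a}}_l}\mathcal{L}\big) {\bm{W}}_l {\bm{D}}_{l-1}
   + \operatorname{diag}\big(\sigma''({\bm{a}}_{l-1}) \odot \nabla_{{\bm{h}}_{l-1}}\mathcal{L}\big),
   \quad {\bm{D}}_{l-1} = \operatorname{diag}\big(\sigma'({\bm{a}}_{l-1})\big), \\
\frac{\partial^2 \mathcal{L}}{\partial W_{l,ij}^2}
  &= \frac{\partial^2 \mathcal{L}}{\partial a_{l,i}^2}\, h_{l-1,j}^2, \\
\frac{\partial^2 \mathcal{L}}{\partial a_{l-1,j}^2}
  &\approx \sigma'(a_{l-1,j})^2 \sum_k W_{l,kj}^2 \,\frac{\partial^2 \mathcal{L}}{\partial a_{l,k}^2}
   + \sigma''(a_{l-1,j})\,\frac{\partial \mathcal{L}}{\partial h_{l-1,j}}.
\end{align*}
```

The first two lines are exact; the full matrix in the first line grows quadratically with layer width, which is why diagonal schemes keep only its diagonal. The third line results from dropping the off-diagonal entries of $\nabla^2_{{\bm{a}}_l}\mathcal{L}$, and removing its $\sigma''$ term gives the Gauss-Newton form.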

Theorems & Definitions (2)

  • Theorem 6.1
  • proof