Reconstructing Deep Neural Networks: Unleashing the Optimization Potential of Natural Gradient Descent
Weihua Liu, Said Boumaraf, Jianwu Li, Chaochao Lin, Xiabi Liu, Lijuan Niu, Naoufel Werghi
TL;DR
This work tackles the computational bottleneck of Natural Gradient Descent (NGD) in training deep neural networks by introducing Structured Natural Gradient Descent (SNGD), which reconstructs the network with Local Fisher Layers so that the global Fisher information matrix $G$ decomposes into local blocks. The key theoretical result is that running standard Gradient Descent (GD) on a transformed parameter $w' = G^{1/2} w$ follows the same optimization trajectory as NGD on the original parameters $w$, so NGD's behavior can be obtained with ordinary GD updates. The authors pair a hierarchical decomposition with an efficient computation of the Fisher matrix square root, based on Nyström approximations and Denman–Beavers iterations; the resulting training procedure retains NGD’s curvature benefits while remaining practical for deep models. Empirically, SNGD converges faster than both NGD and traditional first-order optimizers, and often generalizes better, across MLPs, CNNs, LSTMs, and ResNets on MNIST, CIFAR-10, ImageNet, and Penn Treebank. The work suggests that NGD-like optimization can be scaled to a broad class of architectures, with potential applicability to transformer-based models.
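The equivalence between GD on $w' = G^{1/2} w$ and NGD on $w$ can be checked numerically on a toy problem. The sketch below is an illustration under stated assumptions, not the authors' implementation: it uses a small quadratic loss, a fixed random symmetric positive-definite matrix as a stand-in for the Fisher matrix $G$, and a plain Denman–Beavers iteration without the Nyström approximation. All names (`denman_beavers_sqrt`, `grad`, `w_star`) are hypothetical.

```python
import numpy as np

def denman_beavers_sqrt(A, iters=20):
    """Approximate A^{1/2} and A^{-1/2} for a symmetric positive-definite A.

    Denman-Beavers iteration: start from Y = A, Z = I and repeatedly average
    each matrix with the inverse of the other. Y tends to A^{1/2} and Z to
    A^{-1/2}.
    """
    Y = A.copy()
    Z = np.eye(A.shape[0])
    for _ in range(iters):
        Y_inv = np.linalg.inv(Y)
        Z_inv = np.linalg.inv(Z)
        Y, Z = 0.5 * (Y + Z_inv), 0.5 * (Z + Y_inv)
    return Y, Z

rng = np.random.default_rng(0)
n = 4

# Assumption: a fixed SPD matrix plays the role of the Fisher matrix G.
M = rng.standard_normal((n, n))
G = M @ M.T + n * np.eye(n)

# Assumption: toy quadratic loss L(w) = 0.5 (w - w_star)^T H (w - w_star).
B = rng.standard_normal((n, n))
H = B @ B.T + np.eye(n)
w_star = rng.standard_normal(n)

def grad(w):
    """Gradient of the toy loss with respect to the original parameters w."""
    return H @ (w - w_star)

G_half, G_inv_half = denman_beavers_sqrt(G)

lr, steps = 0.1, 50
w_ngd = np.zeros(n)            # NGD iterate in the original coordinates
w_prime = G_half @ w_ngd       # GD iterate in the transformed coordinates

for _ in range(steps):
    # NGD step: precondition the gradient with G^{-1}.
    w_ngd = w_ngd - lr * np.linalg.solve(G, grad(w_ngd))
    # Plain GD step on w'; by the chain rule, dL/dw' = G^{-1/2} dL/dw.
    w_prime = w_prime - lr * (G_inv_half @ grad(G_inv_half @ w_prime))

# Map the GD iterate back to the original coordinates and compare.
w_from_gd = G_inv_half @ w_prime
max_gap = np.max(np.abs(w_ngd - w_from_gd))
print("max |w_NGD - G^{-1/2} w'_GD| =", max_gap)
```

The two trajectories should agree up to floating-point error because $G$ is held fixed here. In actual training the Fisher matrix changes with the parameters, so this check only illustrates the per-step identity.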
Abstract
Natural gradient descent (NGD) is a powerful optimization technique for machine learning, but the computational cost of inverting the Fisher information matrix limits its application in training deep neural networks. To overcome this challenge, we propose a novel optimization method for training deep neural networks called structured natural gradient descent (SNGD). Theoretically, we demonstrate that optimizing the original network using NGD is equivalent to using fast gradient descent (GD) to optimize a reconstructed network in which the parameter matrix undergoes a structural transformation. This allows us to decompose the calculation of the global Fisher information matrix into the efficient computation of local Fisher matrices by constructing local Fisher layers in the reconstructed network, which speeds up training. Experimental results on various deep networks and datasets demonstrate that SNGD converges faster than NGD while reaching comparable solutions. Furthermore, our method outperforms traditional gradient descent methods in terms of efficiency and effectiveness. Thus, our proposed method has the potential to significantly improve the scalability and efficiency of NGD in deep learning applications. Our source code is available at https://github.com/Chaochao-Lin/SNGD.
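The equivalence claimed in the abstract can be sketched in a few lines. This is a simplified derivation under assumptions that may differ from the paper's exact proof: $G$ denotes the Fisher information matrix, taken to be symmetric positive definite and held fixed during the step; $L$ is the loss, $\eta$ the learning rate, $w$ the original parameters, and $w' = G^{1/2} w$ the transformed parameters.

$$
\begin{aligned}
\nabla_{w'} L &= G^{-1/2}\, \nabla_{w} L && \text{(chain rule, since } w = G^{-1/2} w'\text{)} \\
w'_{t+1} &= w'_t - \eta\, G^{-1/2}\, \nabla_{w} L(w_t) && \text{(GD step on } w'\text{)} \\
G^{-1/2} w'_{t+1} &= G^{-1/2} w'_t - \eta\, G^{-1}\, \nabla_{w} L(w_t) && \text{(multiply both sides by } G^{-1/2}\text{)} \\
w_{t+1} &= w_t - \eta\, G^{-1}\, \nabla_{w} L(w_t) && \text{(NGD step on } w\text{)}
\end{aligned}
$$

Under these assumptions a plain GD step in the transformed coordinates therefore reproduces the NGD step in the original coordinates, without applying $G^{-1}$ to the gradient directly.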
