Beyond the Mean: Fisher-Orthogonal Projection for Natural Gradient Descent in Large Batch Training

Yishun Lu, Wesley Armour

TL;DR

The work tackles the challenge of scalable, effective optimization for extremely large mini-batches by enhancing natural-gradient methods with Fisher-orthogonal variance corrections. By combining the average mini-batch gradient with a Fisher-orthogonal component derived from a second mini-batch, and by adaptively tuning both beta and layer-wise step sizes, FOP preserves curvature information that large-batch damping would otherwise wash out. KL-norm analysis and targeted experiments demonstrate that FOP delivers fast convergence and robust generalization across convolutional and transformer architectures, achieving substantial speedups over SGD, AdamW, and prior second-order methods in CIFAR and ImageNet settings, including long-tailed data. These results suggest that FOP enables practical, scalable second-order optimization for large-scale distributed training with reduced sensitivity to hyperparameter tuning.

Abstract

Modern GPUs are equipped with large amounts of high-bandwidth memory, enabling them to support mini-batch sizes of up to tens of thousands of training samples. However, most existing optimizers struggle to perform effectively at such large batch sizes. As batch size increases, gradient noise decreases due to averaging over many samples, limiting the ability of first-order methods to escape sharp or suboptimal minima and reach the global minimum. Meanwhile, second-order methods like the natural gradient with Kronecker-Factored Approximate Curvature (KFAC) often require excessively high damping to remain stable at large batch sizes. This high damping effectively washes out the curvature information that gives these methods their advantage, reducing their performance to that of simple gradient descent. In this paper, we introduce Fisher-Orthogonal Projection (FOP), a novel technique that restores the effectiveness of second-order methods at very large batch sizes, enabling scalable training with improved generalization and faster convergence. FOP constructs a variance-aware update direction by leveraging gradients from two sub-batches, enhancing the average gradient with a component of the gradient difference that is orthogonal to the average under the Fisher metric.
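
The construction described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the average and difference are taken as `(g1 + g2) / 2` and `g1 - g2`, the mixing weight `beta` is fixed rather than adaptively tuned, the Fisher matrix is a small dense stand-in (KFAC would use Kronecker factors instead), and the names `fop_direction`, `natural_gradient_step`, `eps`, and `damping` are hypothetical.

```python
import numpy as np


def fop_direction(g1, g2, fisher, beta=0.1, eps=1e-8):
    """Combine two sub-batch gradients into a variance-aware direction.

    Assumed form: g_avg + beta * g_perp, where g_perp is the part of the
    gradient difference that is orthogonal to g_avg under the Fisher metric.
    """
    g_avg = 0.5 * (g1 + g2)   # mean gradient of the two sub-batches
    g_diff = g1 - g2          # difference carries the variance signal
    f_avg = fisher @ g_avg
    # Scalar projection of g_diff onto g_avg in the Fisher inner product;
    # eps keeps the denominator away from zero.
    s_proj = (g_diff @ f_avg) / (g_avg @ f_avg + eps)
    g_perp = g_diff - s_proj * g_avg
    return g_avg + beta * g_perp, g_perp


def natural_gradient_step(direction, fisher, damping=1e-3):
    """Precondition a direction by solving (F + damping * I) x = direction."""
    n = fisher.shape[0]
    return np.linalg.solve(fisher + damping * np.eye(n), direction)


# Toy problem: a random symmetric PSD matrix stands in for the Fisher matrix.
rng = np.random.default_rng(0)
n = 6
a = rng.normal(size=(n, n))
fisher = a @ a.T / n
g1 = rng.normal(size=n)
g2 = rng.normal(size=n)

combined, g_perp = fop_direction(g1, g2, fisher)
step = natural_gradient_step(combined, fisher)
```

Because `g_perp` is Fisher-orthogonal to the average gradient up to a term of order `eps`, adding it changes the direction without changing the component along `g_avg` in the Fisher metric.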

Paper Structure

This paper contains 32 sections, 2 theorems, 64 equations, 10 figures, 7 tables, 1 algorithm.

Key Result

Lemma 1

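Let $g_{\text{diff}}, g_{\text{avg}} \in \mathbb{R}^n$, and let $F \in \mathbb{R}^{n \times n}$ be a symmetric positive semi-definite matrix (the Fisher information matrix). Define the scalar projection

$$s_{\text{proj}} = \frac{g_{\text{diff}}^\top F g_{\text{avg}}}{g_{\text{avg}}^\top F g_{\text{avg}} + \epsilon}$$

and the orthogonal component

$$g_{\text{diff}}^\perp = g_{\text{diff}} - s_{\text{proj}}\, g_{\text{avg}}$$

for some small constant $\epsilon > 0$. Then the Fisher inner product between $g_{\text{diff}}^\perp$ and $g_{\text{avg}}$ satisfies

$$\langle g_{\text{diff}}^\perp, g_{\text{avg}} \rangle_F = (g_{\text{diff}}^\perp)^\top F g_{\text{avg}} = \frac{\epsilon\, g_{\text{diff}}^\top F g_{\text{avg}}}{g_{\text{avg}}^\top F g_{\text{avg}} + \epsilon}.$$

In particular, the inner product vanishes as $\epsilon \to 0$.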
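
The orthogonality claim follows from a short expansion. The sketch below assumes the usual $\epsilon$-regularised forms for the two quantities named in the lemma, $s_{\text{proj}} = g_{\text{diff}}^\top F g_{\text{avg}} / (g_{\text{avg}}^\top F g_{\text{avg}} + \epsilon)$ and $g_{\text{diff}}^\perp = g_{\text{diff}} - s_{\text{proj}}\, g_{\text{avg}}$; these definitions are inferred from the wording of the statement and may differ in detail from the paper's.

```latex
\begin{aligned}
\langle g_{\text{diff}}^\perp, g_{\text{avg}} \rangle_F
  &= (g_{\text{diff}} - s_{\text{proj}}\, g_{\text{avg}})^\top F\, g_{\text{avg}} \\
  &= g_{\text{diff}}^\top F g_{\text{avg}}
     - \frac{g_{\text{diff}}^\top F g_{\text{avg}}}
            {g_{\text{avg}}^\top F g_{\text{avg}} + \epsilon}\,
       g_{\text{avg}}^\top F g_{\text{avg}} \\
  &= \frac{\epsilon\, g_{\text{diff}}^\top F g_{\text{avg}}}
          {g_{\text{avg}}^\top F g_{\text{avg}} + \epsilon}.
\end{aligned}
```

Under these assumptions the residual is proportional to $\epsilon$, so whenever $g_{\text{avg}}^\top F g_{\text{avg}} \gg \epsilon$ the two vectors are orthogonal in the Fisher metric for practical purposes, and the regulariser only matters when the average gradient has near-zero Fisher norm.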

Figures (10)

  • Figure 1: 3D loss landscape for ResNet-18 trained on CIFAR-10 with a batch size of 1024. The visualization uses the method of Li et al. (2018). Arrows show the step directions taken under the different gradients. The green star marks the smallest loss reached after updating the model along the different update directions.
  • Figure 2: Test accuracy vs. wall-clock time (in seconds) for ResNet-18 on CIFAR-10, grouped by batch size. The dotted line represents the threshold of 91%.
  • Figure 3: Test accuracy vs. wall-clock time (in seconds) for T2T-ViT on ImageNet-100, grouped by batch size. The dotted line represents the threshold of 80.6%.
  • Figure 4: Test accuracy vs. wall-clock time (in minutes) for ResNet-50 on ImageNet-1K, grouped by batch size. The dotted line represents the threshold of 75.9%.
  • Figure 5: Ablation of the scaling parameter $\eta$ on CIFAR-10. Bars are grouped by optimiser, KFAC (left cluster) and FOP (right cluster), and coloured by the value of the scaling term $\eta\in\{-0.1,-0.01,0,0.01,0.1\}$. Numbers above each bar give the Top 3 test accuracy averaged over three random seeds.
  • ...and 5 more figures

Theorems & Definitions (4)

  • Lemma 1
  • proof
  • Lemma 2
  • proof