
Adaptive Optimization via Momentum on Variance-Normalized Gradients

Francisco Patitucci, Aryan Mokhtari

TL;DR

MVN-Grad applies momentum after variance-based normalization, decoupling the carried-over momentum from the stochastic normalizer and thereby addressing the temporal coupling and sign-type collapse of Adam-style optimizers. The approach replaces the uncentered second moment with a variance proxy; the authors prove a reduction in one-step update variance and a uniformly bounded response to single gradient spikes. In low-variance (high-signal) regimes, variance normalization preserves gradient magnitudes and can converge faster than second-moment normalization. Empirically, MVN-Grad matches or surpasses Adam, AdaBelief, and LaProp on CIFAR-100 image classification and GPT-style language modeling, with smoother training, better generalization, and greater robustness.
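To make the low-variance claim concrete, here is a toy comparison of the two normalizers with idealized, exactly known moments. It rests on our own assumptions: every coordinate shares the same small noise variance $\sigma^2$, and the second moment is exactly $g^2+\sigma^2$. Dividing by $\sqrt{g^2+\sigma^2}$ pushes every coordinate toward $\pm 1$ (sign-like behavior), whereas dividing by $\sqrt{\sigma^2}$ keeps the ratios between coordinates. The function names are hypothetical, not taken from the paper.

```python
import math


def second_moment_normalize(g: float, sigma2: float, eps: float = 1e-8) -> float:
    """Adam-style scaling: divide by sqrt(E[g^2]) = sqrt(g^2 + sigma^2) (idealized)."""
    return g / (math.sqrt(g * g + sigma2) + eps)


def variance_normalize(g: float, sigma2: float, eps: float = 1e-8) -> float:
    """Variance-based scaling: divide by the noise standard deviation only."""
    return g / (math.sqrt(sigma2) + eps)


grads = [0.01, 0.1, 1.0]  # mean gradients spanning two orders of magnitude
sigma2 = 1e-6             # low-variance (high-signal) regime, shared by all coordinates

adam_like = [second_moment_normalize(g, sigma2) for g in grads]
var_like = [variance_normalize(g, sigma2) for g in grads]

print([round(a, 3) for a in adam_like])     # every entry is close to 1: sign-like
print(round(var_like[2] / var_like[0], 6))  # ratio of largest to smallest is kept
```

The variance-normalized values are large in absolute terms, so in practice the learning rate (and $\epsilon$) set the actual step size; the point of the sketch is the preserved ratio, which second-moment scaling discards in this regime.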

Abstract

We introduce MVN-Grad (Momentum on Variance-Normalized Gradients), an Adam-style optimizer that improves stability and performance by combining two complementary ideas: variance-based normalization and momentum applied after normalization. MVN-Grad scales each coordinate by an exponential moving average of gradient uncertainty and applies momentum to the resulting normalized gradients, eliminating the cross-time coupling between stale momentum and a stochastic normalizer present in standard Adam-type updates. We prove that this decoupling yields strictly smaller one-step conditional update variance than momentum-then-normalize variance methods under standard noise assumptions, and that MVN-Grad is robust to outliers: it has a uniformly bounded response to single gradient spikes. In low-variance regimes, we further show variance normalization avoids sign-type collapse associated with second-moment scaling and can yield accelerated convergence. Across CIFAR-100 image classification and GPT-style language modeling benchmarks, MVN-Grad matches or outperforms Adam, AdaBelief, and LaProp, delivering smoother training and improved generalization with no added overhead.
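The two ingredients can be sketched per coordinate as follows. This is a minimal sketch under our own assumptions, not the paper's algorithm: we assume the variance proxy is an EMA of $(g_t-m_{t-1})^2$, that the centering mean $m$ reuses $\beta_1$, and that Adam-style bias correction is applied. `MVNState`, `init_state`, and `mvn_grad_step` are hypothetical names. The essential point is the order of operations: the fresh gradient is divided by the variance-based normalizer first, and only normalized values enter the momentum buffer, so stale momentum is never rescaled by the current, stochastic normalizer.

```python
import math
import random
from dataclasses import dataclass
from typing import List


@dataclass
class MVNState:
    """Per-coordinate optimizer buffers (hypothetical layout)."""
    m: List[float]  # EMA of raw gradients, used only to center the variance proxy
    s: List[float]  # EMA of squared deviations (variance proxy)
    p: List[float]  # momentum on the *normalized* gradients
    t: int = 0      # step counter for bias correction


def init_state(n: int) -> MVNState:
    return MVNState(m=[0.0] * n, s=[0.0] * n, p=[0.0] * n)


def mvn_grad_step(x, g, state, lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return updated parameters after one step of the assumed MVN-Grad form."""
    state.t += 1
    new_x = []
    for i, (xi, gi) in enumerate(zip(x, g)):
        # 1) Variance proxy: deviation of the new gradient from the previous mean estimate.
        dev = gi - state.m[i]
        state.s[i] = beta2 * state.s[i] + (1 - beta2) * dev * dev
        s_hat = state.s[i] / (1 - beta2 ** state.t)
        # 2) Normalize the fresh gradient first ...
        u = gi / (math.sqrt(s_hat) + eps)
        # 3) ... then apply momentum to the normalized gradient.
        state.p[i] = beta1 * state.p[i] + (1 - beta1) * u
        p_hat = state.p[i] / (1 - beta1 ** state.t)
        # 4) Refresh the centering mean after it has been used.
        state.m[i] = beta1 * state.m[i] + (1 - beta1) * gi
        new_x.append(xi - lr * p_hat)
    return new_x


def loss(v, c):
    """Quadratic test objective f(v) = sum_i (v_i - c_i)^2."""
    return sum((vi - ci) ** 2 for vi, ci in zip(v, c))


# Usage: noisy quadratic with Gaussian gradient noise (std 0.1).
rng = random.Random(0)
c = [1.0, -2.0, 3.0]
x = [0.0, 0.0, 0.0]
state = init_state(len(x))
initial_loss = loss(x, c)
for _ in range(2000):
    grad = [2.0 * (xi - ci) + rng.gauss(0.0, 0.1) for xi, ci in zip(x, c)]
    x = mvn_grad_step(x, grad, state)
final_loss = loss(x, c)
print(initial_loss, final_loss)
```

Compared with LaProp, which also normalizes before applying momentum, the only change in this sketch is the normalizer: a centered variance proxy instead of the uncentered second moment.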

Paper Structure (45 sections, 8 theorems, 79 equations, 5 figures, 9 tables, 1 algorithm)

Key Result

Theorem 3.1

Assume that, conditional on $\mathcal{F}_{t-1}$, the centered gradient $g_t-\mu_t$ is symmetric, where $\mu_t:=\mathbb{E}[g_t\mid\mathcal{F}_{t-1}]$. Assume moreover that the EMA tracks the conditional mean, i.e., $m_{t-1}=\mu_t$. Define $\Delta\mathrm{Var}_t := \mathrm{Var}$ …
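The mechanism behind Theorem 3.1 can be checked numerically at a single frozen state. The sketch below is our own toy under stated assumptions, not the paper's definition of $\Delta\mathrm{Var}_t$ (whose full statement is not reproduced here). Both orderings share the assumed variance proxy $s_t=\beta_2 s_{t-1}+(1-\beta_2)(g_t-m_{t-1})^2$; the momentum-then-normalize update is taken as $(\beta_1 m_{t-1}+(1-\beta_1)g_t)\,r_t$ and the normalize-then-momentum update as $\beta_1 p_{t-1}+(1-\beta_1)g_t\,r_t$, with $r_t=1/(\sqrt{s_t}+\epsilon)$. With $m_{t-1}=\mu_t$, the factor $r_t$ is even in the noise while $(g_t-\mu_t)\,r_t$ is odd, so under symmetric noise their covariance vanishes and the variance gap reduces to $\mu_t^2\,\beta_1(2-\beta_1)\,\mathrm{Var}(r_t)\ge 0$. Gaussian noise is made exactly symmetric with antithetic pairs, and $\beta_2=0.9$ is chosen only so that $r_t$ visibly fluctuates. `one_step_updates` is a hypothetical helper.

```python
import math
import random
import statistics


def one_step_updates(mu, sigma, s_prev, p_prev, beta1=0.9, beta2=0.9,
                     eps=1e-8, n=10000, seed=0):
    """Sample one-step updates of both orderings at a frozen state with m_{t-1} = mu."""
    rng = random.Random(seed)
    m_prev = mu  # theorem assumption: the EMA tracks the conditional mean
    mtn, ntm, r_vals = [], [], []
    for _ in range(n // 2):
        e = rng.gauss(0.0, sigma)
        for noise in (e, -e):  # antithetic pair: the noise is exactly symmetric
            g = mu + noise
            s = beta2 * s_prev + (1 - beta2) * (g - m_prev) ** 2
            r = 1.0 / (math.sqrt(s) + eps)
            # Momentum-then-normalize: stale momentum is rescaled by the new, random r.
            mtn.append((beta1 * m_prev + (1 - beta1) * g) * r)
            # Normalize-then-momentum: stale momentum p_prev is constant given F_{t-1}.
            ntm.append(beta1 * p_prev + (1 - beta1) * g * r)
            r_vals.append(r)
    return mtn, ntm, r_vals


mu, beta1 = 1.0, 0.9
mtn, ntm, r_vals = one_step_updates(mu=mu, sigma=2.0, s_prev=1.0, p_prev=0.5, beta1=beta1)
gap = statistics.pvariance(mtn) - statistics.pvariance(ntm)
predicted = mu ** 2 * beta1 * (2 - beta1) * statistics.pvariance(r_vals)
print(f"gap={gap:.6f}  predicted={predicted:.6f}")
```

The extra variance in the momentum-then-normalize ordering comes entirely from the stale term $\beta_1 m_{t-1}$ being multiplied by the random factor $r_t$; in this toy it disappears when $\mu_t=0$, which matches the intuition that the coupling only matters when there is signal to carry over.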

Figures (5)

  • Figure 1: MNIST training loss for different batch sizes. Full hyperparameters are given in the appendix.
  • Figure 2: Conditional update-variance gap, estimated by Monte Carlo at frozen checkpoints and averaged over three seeds.
  • Figure 3: Delayed single-spike robustness: peak update magnitude $\max_{0\le \tau \le T}|\Delta_\tau|$ versus spike size $M$ (log-log axes). Within each panel, all optimizers use the same hyperparameters. Experimental details are given in the appendix.
  • Figure 4: Hyperparameter robustness across sweep runs (validation): (a) CIFAR-100 bs=128, (b) CIFAR-100 bs=1024, (c) OpenWebText.
  • Figure 5: Hyperparameter robustness across sweep runs (training): (a) CIFAR-100 bs=128, (b) CIFAR-100 bs=1024, (c) OpenWebText.
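The bounded single-spike response shown in Figure 3 can be illustrated in miniature. The toy below rests on our own assumptions: a scalar update with variance proxy built from $(g_t-m_{t-1})^2$, no bias correction, warm-started buffers, and learning rate 1. It injects one spike of size $M$ into otherwise $\mathcal{N}(0,1)$ gradients and records $\max_\tau|\Delta_\tau|$. Because the spike enters the normalizer in the same step, $|g_t|/\sqrt{s_t}$ approaches $1/\sqrt{1-\beta_2}$ as $M\to\infty$, so the spike adds at most about $(1-\beta_1)/\sqrt{1-\beta_2}$ to the update regardless of $M$. It does not reproduce the paper's cross-optimizer comparison or its delayed-spike protocol; `mvn_update` and `peak_update` are hypothetical names.

```python
import math
import random


def mvn_update(g, state, lr=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar step of the assumed MVN-Grad form; returns the signed update lr * p."""
    m, s, p = state
    s = beta2 * s + (1 - beta2) * (g - m) ** 2  # the spike enters the normalizer at once
    u = g / (math.sqrt(s) + eps)                # for huge g, |u| approaches 1/sqrt(1 - beta2)
    p = beta1 * p + (1 - beta1) * u             # momentum on the normalized gradient
    m = beta1 * m + (1 - beta1) * g             # centering mean, refreshed after use
    state[:] = [m, s, p]
    return lr * p


def peak_update(spike, steps=200, spike_at=100, seed=0):
    """Peak |Delta_tau| over a run of N(0, 1) gradients with one spike of size `spike`."""
    rng = random.Random(seed)
    state = [0.0, 1.0, 0.0]  # warm start (assumption): mean 0, variance 1, zero momentum
    peak = 0.0
    for t in range(steps):
        g = rng.gauss(0.0, 1.0)  # same noise sequence for every spike size
        if t == spike_at:
            g += spike
        peak = max(peak, abs(mvn_update(g, state)))
    return peak


spike_cap = (1 - 0.9) / math.sqrt(1 - 0.999)  # (1 - beta1) / sqrt(1 - beta2) with lr = 1
peaks = {M: peak_update(M) for M in (1e2, 1e4, 1e6, 1e8)}
for M, pk in peaks.items():
    print(f"M={M:.0e}  peak={pk:.4f}  spike cap={spike_cap:.4f}")
```

In this toy, the peak stays essentially flat once $M$ is large, because the spike's own contribution to $s_t$ grows as fast as the spike itself.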

Theorems & Definitions (18)

  • Remark 2.1
  • Remark 2.2
  • Theorem 3.1
  • Theorem 3.2
  • Remark 3.1
  • Theorem 3.3
  • Proof
  • Proof
  • Lemma B.1
  • Proof
  • ...and 8 more