Muon is Provably Faster with Momentum Variance Reduction
Xun Qian, Hussein Rammal, Dmitry Kovalev, Peter Richtárik
TL;DR
The paper addresses the inefficiency of vanilla momentum in LMO-based optimizers for large-language-model training. It introduces Momentum Variance Reduction (MVR) within the Gluon framework, yielding three variants (Gluon-MVR-1/2/3) that achieve faster non-convex convergence rates under layer-wise $(L^0,L^1)$-smoothness, and it also obtains improved star-convex rates for Muon-MVR. Theoretical results establish $\tilde{\mathcal{O}}(1/K^{1/3})$ rates (and $\tilde{\mathcal{O}}(1/K^{1/2})$ for Muon-MVR in the star-convex case) under realistic assumptions, with constant- and decreasing-step-size analyses providing complementary insights. Empirical tests on NanoGPT-124M trained on FineWeb-10B corroborate the theoretical benefits, showing improved iteration efficiency and lower validation losses for the MVR-enhanced Gluon variants, particularly in larger or noisier regimes.
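To make the contrast with vanilla momentum concrete, here is a minimal NumPy sketch of an LMO-based step driven by a momentum-variance-reduction estimator, for a single matrix-shaped layer under the spectral norm (the Muon setting). It rests on several assumptions that are not taken from the paper. The estimator is the STORM-style recursion $m_k = g(X_k;\xi_k) + (1-\beta)\,(m_{k-1} - g(X_{k-1};\xi_k))$, which may not coincide exactly with any of Gluon-MVR-1/2/3. The LMO is computed with an exact SVD rather than the Newton-Schulz iterations Muon uses in practice. The toy least-squares problem and the values of `lr`, `beta` and `batch` are hypothetical and chosen only for illustration. Gluon's per-layer norms and radii are collapsed here to a single layer.

```python
import numpy as np


def lmo_spectral(m):
    """LMO over the unit spectral-norm ball: argmin_{||D||_op <= 1} <M, D>.

    For the reduced SVD M = U S V^T the minimiser is -U V^T (exact SVD here;
    Muon itself approximates the orthogonal factor with Newton-Schulz steps).
    """
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return -u @ vt


def mvr_update(m_prev, grad_curr, grad_prev_same_batch, beta):
    """STORM-style momentum variance reduction (assumed form, see lead-in).

    Both gradients must come from the SAME minibatch, evaluated at the current
    and the previous iterate. With beta = 1 the estimator reduces to the
    current stochastic gradient.
    """
    return grad_curr + (1.0 - beta) * (m_prev - grad_prev_same_batch)


def minibatch_grad(x, a, b, idx):
    # Gradient of 0.5 * mean_{i in idx} ||a_i X - b_i||^2 (hypothetical toy loss).
    a_s, b_s = a[idx], b[idx]
    return a_s.T @ (a_s @ x - b_s) / len(idx)


def loss(x, a, b):
    # Full-batch value of the same toy least-squares objective.
    return 0.5 * np.mean(np.sum((a @ x - b) ** 2, axis=1))


def run_lmo_mvr(a, b, steps=500, lr=0.05, beta=0.1, batch=32, seed=0):
    """Iterate X_{k+1} = X_k + lr * LMO(m_k) with an MVR momentum m_k."""
    rng = np.random.default_rng(seed)
    n, d = a.shape
    x = np.zeros((d, b.shape[1]))
    # Initialise the momentum with one plain stochastic gradient.
    m = minibatch_grad(x, a, b, rng.choice(n, size=batch, replace=False))
    x_prev, x = x, x + lr * lmo_spectral(m)
    for _ in range(steps):
        idx = rng.choice(n, size=batch, replace=False)  # one shared minibatch
        m = mvr_update(m, minibatch_grad(x, a, b, idx),
                       minibatch_grad(x_prev, a, b, idx), beta)
        x_prev, x = x, x + lr * lmo_spectral(m)
    return x


# Usage on a hypothetical noiseless least-squares instance B = A X_star.
rng = np.random.default_rng(1)
a_demo = rng.standard_normal((64, 10))
b_demo = a_demo @ rng.standard_normal((10, 5))
x_hat = run_lmo_mvr(a_demo, b_demo)
print("loss at X = 0:", loss(np.zeros((10, 5)), a_demo, b_demo))
print("loss after LMO-MVR:", loss(x_hat, a_demo, b_demo))
```

The essential difference from vanilla (exponential-moving-average) momentum is in `mvr_update`: the same minibatch is evaluated at both the current and the previous iterate, so the correction term `m_prev - grad_prev_same_batch` removes the stale part of the momentum instead of merely averaging it. In this sketch, the price is a second gradient evaluation per step.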
Abstract
Recent empirical research has demonstrated that deep learning optimizers based on the linear minimization oracle (LMO) over specifically chosen non-Euclidean norm balls, such as Muon and Scion, outperform Adam-type methods in the training of large language models. In this work, we show that such optimizers can be provably improved by replacing their vanilla momentum with momentum variance reduction (MVR). Instead of proposing and analyzing MVR variants of Muon and Scion separately, we incorporate MVR into the recently proposed Gluon framework, which captures Muon, Scion and other specific non-Euclidean LMO-based methods as special cases, and at the same time works with a more general smoothness assumption that better captures the layer-wise structure of neural networks. In the non-convex case, we incorporate MVR into Gluon in three different ways. All of them improve the convergence rate from $\mathcal{O}(1/K^{1/4})$ to $\mathcal{O}(1/K^{1/3})$. Additionally, we provide improved rates in the star-convex case. Finally, we conduct several numerical experiments that verify the superior performance of our proposed algorithms in terms of iteration complexity.
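As a worked illustration of what the rate improvement means for iteration complexity, suppose a stationarity measure (written $\mathcal{E}_K$ here as a placeholder; the paper's exact layer-wise measure is not reproduced) is bounded by $C/K^{1/4}$ for vanilla momentum and by $C'/K^{1/3}$ with MVR. Solving each bound for the number of iterations $K$ needed to reach accuracy $\varepsilon$ gives:

```latex
\frac{C}{K^{1/4}} \le \varepsilon \iff K \ge \left(\frac{C}{\varepsilon}\right)^{4},
\qquad
\frac{C'}{K^{1/3}} \le \varepsilon \iff K \ge \left(\frac{C'}{\varepsilon}\right)^{3}.
```

Ignoring the constants hidden by the $\mathcal{O}$-notation (setting $C = C' = 1$, which is only an illustrative assumption), a target of $\varepsilon = 10^{-2}$ needs on the order of $10^{8}$ iterations under the $K^{-1/4}$ rate but $10^{6}$ under the $K^{-1/3}$ rate.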
