MKOR: Momentum-Enabled Kronecker-Factor-Based Optimizer Using Rank-1 Updates

Mohammad Mozaffari, Sikan Li, Zhao Zhang, Maryam Mehri Dehnavi

TL;DR

MKOR is a momentum-enabled Kronecker-factor optimizer that uses rank-1 updates to approximate curvature, reducing inversion and communication costs enough to make frequent second-order updates affordable. The method combines Sherman-Morrison (SM)-based rank-1 inversions, norm-based stabilization, gradient rescaling, and half-precision computation. A hybrid variant, MKOR-H, switches to first-order updates when second-order updates stop helping. On BERT-Large-Uncased, MKOR achieves speedups of up to 2.57x over LAMB and 1.85x over KAISA, with competitive or better accuracy on GLUE and additional gains on ResNet-50. The approach shortens end-to-end training of large transformer and CNN models while keeping convergence reliable.

Abstract

This work proposes a Momentum-Enabled Kronecker-Factor-Based Optimizer Using Rank-1 Updates, called MKOR, that improves the training time and convergence properties of deep neural networks (DNNs). Second-order techniques, while enjoying higher convergence rates than their first-order counterparts, have cubic complexity with respect to the model size, the training batch size, or both. Hence they exhibit poor scalability and performance in transformer models, e.g. large language models (LLMs), because the batch sizes in these models scale with the attention mechanism's sequence length, leading to large model and batch sizes. MKOR's complexity is quadratic with respect to the model size, alleviating the computation bottlenecks in second-order methods. Because of their high computation complexity, state-of-the-art implementations of second-order methods can only afford to update the second-order information infrequently, and thus do not fully exploit the promise of better convergence from these updates. By reducing the communication complexity of the second-order updates to linear, MKOR increases the frequency of second-order updates. We also propose a hybrid version of MKOR (called MKOR-H) that falls back to a first-order optimizer mid-training if the second-order updates no longer accelerate convergence. Our experiments show that MKOR outperforms state-of-the-art first-order methods, e.g. the LAMB optimizer, and the best implementations of second-order methods, i.e. KAISA/KFAC, by up to 2.57x and 1.85x respectively on BERT-Large-Uncased on 64 GPUs.
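The quadratic cost stated in the abstract comes from replacing a full factor inversion with a rank-1 update of an already known inverse. The following is a minimal sketch, not the authors' implementation. It assumes that "SM" refers to the Sherman-Morrison identity and that a factor is refreshed as $\gamma A + (1-\gamma) v v^\top$; the names `sm_rank1_inverse_update`, `A_inv`, `v`, and `gamma` are illustrative and do not come from the paper. Plain Python lists are used to keep the example dependency-free.

```python
def mat_vec(M, v):
    """Return the matrix-vector product M @ v for a list-of-lists matrix."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def sm_rank1_inverse_update(A_inv, v, gamma):
    """Return the inverse of gamma * A + (1 - gamma) * v v^T, given A_inv = A^{-1}.

    Assumes A is symmetric (hypothetical setup, see the lead-in).
    Uses the Sherman-Morrison identity, so the cost is O(d^2):
    one matrix-vector product plus one outer product, with no
    explicit matrix inversion.
    """
    u = mat_vec(A_inv, v)  # u = A^{-1} v; by symmetry, v^T A^{-1} = u^T
    denom = gamma + (1.0 - gamma) * sum(v_i * u_i for v_i, u_i in zip(v, u))
    scale = (1.0 - gamma) / denom
    d = len(v)
    return [
        [(A_inv[i][j] - scale * u[i] * u[j]) / gamma for j in range(d)]
        for i in range(d)
    ]


# Usage: start from A = I and fold in v = [1, 0] with gamma = 0.5.
# The updated factor is [[1, 0], [0, 0.5]], whose inverse is [[1, 0], [0, 2]].
A_inv = [[1.0, 0.0], [0.0, 1.0]]
new_inv = sm_rank1_inverse_update(A_inv, [1.0, 0.0], gamma=0.5)
print(new_inv)
```

A direct inversion of a $d \times d$ factor costs $O(d^3)$, whereas every step above touches each matrix entry a constant number of times, which is where the quadratic scaling comes from in this sketch.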

Paper Structure (25 sections, 3 theorems, 20 equations, 12 figures, 7 tables, 1 algorithm)

Key Result

Lemma 3.1

The factors computed using the rank-1 left-factor and right-factor inversion equations are all positive-definite.
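The intuition behind this kind of result can be sketched for an idealized case. This is not the paper's proof, and the symbols $A$, $v$, and $\gamma$ are illustrative rather than the paper's notation. Assume a factor is refreshed by an exact momentum-weighted rank-1 update, with $A \succ 0$ and $\gamma \in (0, 1)$. Then for every $x \neq 0$:

$$
x^\top \bigl(\gamma A + (1-\gamma)\, v v^\top\bigr)\, x \;=\; \gamma\, x^\top A x \;+\; (1-\gamma)\,(v^\top x)^2 \;>\; 0 .
$$

The first term is strictly positive because $A \succ 0$, and the second is non-negative, so the updated factor is positive-definite. Since the inverse of a positive-definite matrix is positive-definite, an inverse maintained under this assumption stays positive-definite by induction, provided the initial factor is.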

Figures (12)

  • Figure 1: MKOR for layer $m$ on a single worker. The inputs of MKOR are the activations $A_t^m$, the gradients of the loss function with respect to the inputs $G_t^m$, and the gradients of the loss function with respect to the weights $\nabla_{W^m} \mathcal{L}$. The output is the update values $\Delta W^m$.
  • Figure 2: The training loss of BERT-Large-Uncased using different optimizers.
  • Figure 3: Per-step breakdown of different optimizers on BERT-Large-Uncased (a) and ResNet-50 (b).
  • Figure 4: The sensitivity of MKOR and KAISA for BERT-Large-Uncased and an Autoencoder model (a) and the effect of inversion frequency on the convergence properties of these models (b).
  • Figure 5: Rank-1 error for activation and input gradient covariance matrices for BERT-Large-Uncased pre-training (a, b) and ResNet-50 on ImageNet (c, d).
  • ...and 7 more figures

Theorems & Definitions (6)

  • Lemma 3.1
  • Lemma 3.2
  • Lemma 3.3
  • proof
  • proof
  • proof