KOALA++: Efficient Kalman-Based Optimization with Gradient-Covariance Products

Zixuan Xia, Aram Davtyan, Paolo Favaro

TL;DR

KOALA++ advances Kalman-based optimization by propagating a directional gradient-covariance surrogate $v_k = H_k P_{k-1}$ to capture structured uncertainty without storing the full covariance $P_k$. It introduces low-rank reparameterizations, a recursive $v_k$ update, and two least-squares covariance-estimation variants (vanilla and symmetric), yielding an update that remains near first-order in cost while incorporating directional curvature information. Empirically, KOALA++ matches or surpasses strong first- and second-order baselines across image classification and language modeling benchmarks, with favorable stability and efficiency. The approach offers a practical bridge between expressiveness and scalability for large-scale neural optimization, with potential for integration into pretraining and transformer-based workloads. Future work includes enforcing positive semi-definiteness of the covariance surrogate and extending the method to larger-scale, real-world models.

Abstract

We propose KOALA++, a scalable Kalman-based optimization algorithm that explicitly models structured gradient uncertainty in neural network training. Unlike second-order methods, which rely on expensive second-order gradient calculations, our method directly estimates the parameter covariance matrix by recursively updating compact gradient-covariance products. This design improves upon the original KOALA framework, which assumed a diagonal covariance, by implicitly capturing richer uncertainty structure without storing the full covariance matrix or performing large matrix inversions. Across diverse tasks, including image classification and language modeling, KOALA++ achieves accuracy on par with or better than state-of-the-art first- and second-order optimizers while maintaining the efficiency of first-order methods.
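To illustrate why a gradient-covariance product can stand in for the full covariance, the following minimal sketch compares a textbook scalar-measurement Kalman update written with $P_{k-1}$ against the same update written only with $v_k = H_k P_{k-1}$. It is not the KOALA++ algorithm itself: the function names, the measurement-noise variance `R`, and the zero target for the loss are assumptions made for illustration, and the recursive estimation of $v_k$ described in the paper is not shown.

```python
import numpy as np

def kalman_step_full(theta, P, H, loss, R):
    """Textbook Kalman update for a scalar measurement, using the full covariance P.

    Assumes (for illustration) that the measured loss has target value 0.
    """
    S = H @ P @ H + R          # innovation variance, a scalar
    K = P @ H / S              # Kalman gain, shape (n,)
    return theta - K * loss

def kalman_step_product(theta, v, H, loss, R):
    """The same update expressed only through v = H P, a length-n vector."""
    S = v @ H + R              # equals H P H^T + R
    return theta - v * loss / S

rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
P = A @ A.T + np.eye(n)        # symmetric positive-definite covariance
H = rng.standard_normal(n)     # gradient of the loss w.r.t. the parameters
theta = rng.standard_normal(n)
loss, R = 0.7, 0.1             # hypothetical loss value and noise variance

v = H @ P                      # gradient-covariance product
full = kalman_step_full(theta, P, H, loss, R)
prod = kalman_step_product(theta, v, H, loss, R)
print(np.allclose(full, prod))
```

Because $P_{k-1}$ is symmetric, $P_{k-1} H_k^\top = v_k^\top$, so the two functions agree. The second form touches only length-$n$ vectors, which is what makes tracking $v_k$ instead of $P_k$ attractive when $n$ is the number of network parameters.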

Paper Structure

This paper contains 49 sections, 53 equations, 11 figures, 12 tables, 2 algorithms.

Figures (11)

  • Figure 1: Comparison of training loss, validation loss, and validation error for different optimizers on CIFAR-100 using ResNet-50. KOALA++ shows the sharpest drops in loss and error at the scheduled learning-rate decays (epochs 30, 60, and 90), highlighting its superior scheduler responsiveness.
  • Figure 2: Ablation study comparing KOALA++, its asymmetric variant KOALA++ (NS), and KOALA-M on CIFAR-100 with ResNet-50. Left: Log-scale test loss curves. Right: Validation Top-1 error rates.
  • Figure 3: Evolution of the large and small eigenvalues of the matrix $P$ at epoch 0 and epoch 30 across all mini-batches.
  • Figure 4: Empirical verification of the positive semi-definiteness of $M_{k}$ in KOALA++. The plots show the evolution of the angle between the curvature vector $H_{k}$ and the update direction $v_{k}$ across training for three datasets. The consistently acute angles ($<90^\circ$) indicate that $H_{k}$ and $v_{k}$ remain directionally aligned, confirming the effective PSD behavior of $M_{k}$ in practice.
  • Figure 5: CIFAR-10 with ResNet18 (Left) and ResNet50 (Right) trained for 100 epochs under a cosine learning-rate scheduler. Validation Top-1 errors (the inset highlights the late-epoch region) are reported.
  • ...and 6 more figures