Second-Order, First-Class: A Composable Stack for Curvature-Aware Training

Mikalai Korbit, Mario Zanon

Abstract

Second-order methods promise improved stability and faster convergence, yet they remain underused due to implementation overhead, tuning brittleness, and the lack of composable APIs. We introduce Somax, a composable Optax-native stack that treats curvature-aware training as a single JIT-compiled step governed by a static plan. Somax exposes first-class modules -- curvature operators, estimators, linear solvers, preconditioners, and damping policies -- behind a single step interface and composes with Optax by applying standard gradient transformations (e.g., momentum, weight decay, schedules) to the computed direction. This design makes typically hidden choices explicit and swappable. Somax separates planning from execution: it derives a static plan (including cadences) from module requirements, then runs the step through a specialized execution path that reuses intermediate results across modules. We report system-oriented ablations showing that (i) composition choices materially affect scaling behavior and time-to-accuracy, and (ii) planning reduces per-step overhead relative to unplanned composition with redundant recomputation.

Paper Structure

This paper contains 46 sections, 8 equations, 3 figures, 6 tables.

Figures (3)

  • Figure 1: Somax system architecture. The assembler and planner resolve a user configuration into a static execution plan. This plan records one lane-specific execution path (diagonal, parameter-space, or row-space), the step-metric outputs, and cadence gates; the resulting step is then JIT-compiled. The executor manages the full update lifecycle -- from linearization to Optax application and optional post-update control signals -- ensuring that post-update metrics are consistent with the applied update.
  • Figure 2: Steady-state step time for matched parameter-space and row-space configurations under a shared training interface. Left: increasing model size at fixed batch size. Right: increasing batch size at fixed model size.
  • Figure 3: Wall-clock performance comparison of SGDM and Sophia-style optimizer regimes on CIFAR-10 with ResNet-20. Thick lines show mean test accuracy across 5 seeds; shaded bands indicate $\pm 1$ standard deviation.
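The planning idea in the Figure 1 caption (resolving a configuration into a static plan with one execution lane and cadence gates) can be illustrated with a small, dependency-free Python sketch. Everything here is hypothetical: `ModuleSpec`, `Plan`, `make_plan`, and `active_modules` are names invented for this example and are not Somax identifiers. Only the three lane names come from the caption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModuleSpec:
    # Hypothetical module description: what it needs and how often it runs.
    name: str
    requires: tuple
    cadence: int = 1  # run every `cadence` steps


@dataclass(frozen=True)
class Plan:
    # Immutable plan: one lane, the shared intermediates to compute once
    # per step, and a (module name, cadence) gate for each module.
    lane: str
    intermediates: tuple
    gates: tuple


LANES = ("diagonal", "parameter-space", "row-space")


def make_plan(lane, modules):
    if lane not in LANES:
        raise ValueError(f"unknown lane: {lane}")
    # Deduplicate requirements so a quantity shared by several modules
    # appears once in the plan, in first-seen order.
    intermediates = []
    for module in modules:
        for req in module.requires:
            if req not in intermediates:
                intermediates.append(req)
    gates = tuple((m.name, m.cadence) for m in modules)
    return Plan(lane, tuple(intermediates), gates)


def active_modules(plan, step):
    # Cadence gate: a module is active when the step index is a
    # multiple of its cadence.
    return [name for name, cadence in plan.gates if step % cadence == 0]
```

For example, a solver and a damping rule that both require the gradient would yield a plan listing `"grad"` once, and a preconditioner with `cadence=10` would be active only on steps 0, 10, 20, and so on. Because the plan is frozen before execution, a JIT-compiled step could treat it as a compile-time constant, which is one plausible reading of "static plan" in the caption.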