Efficient, Accurate and Stable Gradients for Neural ODEs
Sam McCallum, James Foster
TL;DR
This work tackles the memory and time bottlenecks of training Neural ODEs by introducing a general class of algebraically reversible ODE solvers. These solvers compute exact gradients in $O(n)$ time and $O(1)$ memory, where $n$ is the number of solver steps, outperforming online recursive checkpointing while preserving high-order convergence and numerical stability. The reversible framework accepts any single-step base solver and supports adaptive step sizes, and it is backed by a convergence theorem and a stability analysis. Empirically, reversible training yields substantial runtime and memory savings across multiple scientific and chaotic systems while reaching identical training losses, and the approach offers clear pathways to Neural CDEs, SDEs, PDEs, and implicit models.
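Because the framework wraps any single-step base solver, a small example helps show what "algebraically reversible" means in practice. The Python sketch below couples two copies of the state, `y` and `z`, around a classical RK4 base step, so that each step can be inverted exactly in algebra and earlier states can be rebuilt during the backward pass instead of being stored. The specific update rule, the coupling parameter `lam`, and all function names are illustrative assumptions, not the paper's reference implementation.

```python
import cmath
from typing import Callable

# Vector field f(t, y); a complex scalar lets a 2D rotation fit in one number.
Field = Callable[[float, complex], complex]


def rk4_increment(f: Field, t: float, y: complex, h: float) -> complex:
    """Increment Psi_h(t, y) of one classical RK4 step: y_next = y + Psi_h(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def reversible_step(f: Field, t: float, y: complex, z: complex, h: float,
                    lam: float = 0.99) -> tuple[complex, complex]:
    """Forward step t -> t + h of a lambda-coupled reversible scheme (assumed form)."""
    y_next = lam * y + (1 - lam) * z + rk4_increment(f, t, z, h)
    # z advances by undoing a backward base step taken from the new y.
    z_next = z - rk4_increment(f, t + h, y_next, -h)
    return y_next, z_next


def reversible_step_inverse(f: Field, t: float, y_next: complex, z_next: complex,
                            h: float, lam: float = 0.99) -> tuple[complex, complex]:
    """Algebraically invert reversible_step: recover (y, z) at t from the state at t + h."""
    z = z_next + rk4_increment(f, t + h, y_next, -h)
    y = (y_next - (1 - lam) * z - rk4_increment(f, t, z, h)) / lam
    return y, z


if __name__ == "__main__":
    f = lambda t, y: 1j * y          # y' = i y, exact solution y(t) = exp(i t)
    h, n = 0.1, 50
    y = z = 1 + 0j
    for k in range(n):               # forward solve keeps only the current (y, z)
        y, z = reversible_step(f, k * h, y, z, h)
    y_end = y
    for k in reversed(range(n)):     # backward pass rebuilds each earlier state on the fly
        y, z = reversible_step_inverse(f, k * h, y, z, h)
    print("forward error:", abs(y_end - cmath.exp(1j * n * h)))
    print("reconstruction error:", abs(y - 1), abs(z - 1))
```

Running it should print a small forward error and reconstruction errors close to machine precision. In this oscillatory test problem, taking `lam` slightly below 1 damps the spurious mode introduced by carrying two states; a much smaller `lam` would make the division in the inverse step amplify round-off, so 0.99 is only a plausible default for the sketch.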
Abstract
Training Neural ODEs requires backpropagating through an ODE solve. The state-of-the-art backpropagation method is recursive checkpointing, which balances recomputation against memory cost. Here, we introduce a class of algebraically reversible ODE solvers that significantly improve upon both the time and memory cost of recursive checkpointing. The reversible solvers presented calculate exact gradients and are high-order and numerically stable, strictly improving on previous reversible architectures.
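To make the recomputation/memory trade-off concrete, the sketch below counts step re-executions during backpropagation through $n$ solver steps. It uses plain recursive bisection as a simplified stand-in for the binomial checkpointing schedules used in practice (an assumption: optimal schedules have better constants but the same $O(n \log n)$ time and $O(\log n)$ memory growth), and compares it with a reversible solver, which reconstructs each earlier state exactly once from the current one. All function names are hypothetical.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def bisection_recompute(n: int) -> int:
    """Steps re-run while backpropagating through n steps with checkpoints
    placed by recursive bisection (simplified stand-in for binomial checkpointing)."""
    if n == 1:
        return 1  # re-run the single step so it can be differentiated locally
    half = n // 2
    # Advance from the left checkpoint to the midpoint, then handle both halves.
    return half + bisection_recompute(n - half) + bisection_recompute(half)


@lru_cache(maxsize=None)
def bisection_peak_states(n: int) -> int:
    """Peak number of stored states for the same scheme (log2(n) + 1 for powers of two)."""
    if n == 1:
        return 1
    # The left checkpoint stays alive while the (larger) right half is processed.
    return 1 + bisection_peak_states(n - n // 2)


def reversible_recompute(n: int) -> int:
    """A reversible solver inverts each step once during the backward pass."""
    return n


def reversible_peak_states(n: int) -> int:
    """Only the current coupled pair of states is kept, independent of n."""
    return 2


if __name__ == "__main__":
    for n in (8, 64, 1024):
        print(n, bisection_recompute(n), bisection_peak_states(n),
              reversible_recompute(n), reversible_peak_states(n))
```

Constant factors are ignored here (an inverse reversible step costs roughly two base-solver increments, for example), so only the growth rates in $n$ are meaningful: $O(n \log n)$ recomputation and $O(\log n)$ stored states for checkpointing, versus $O(n)$ and $O(1)$ for the reversible solver.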
