Efficient, Accurate and Stable Gradients for Neural ODEs

Sam McCallum, James Foster

TL;DR

This work tackles the memory and time bottlenecks of training Neural ODEs by introducing a general class of algebraically reversible ODE solvers. These solvers enable exact gradient computation with $O(n)$ time and $O(1)$ memory, outperforming online recursive checkpointing and preserving high-order convergence and numerical stability. The reversible framework works with any single-step base solver and supports adaptive step sizes, backed by a convergence theorem and stability analysis. Empirically, reversible training yields substantial runtime and memory savings across multiple scientific and chaotic systems while achieving identical training losses, with clear pathways to extending the approach to Neural CDEs, SDEs, PDEs, and implicit models.

Abstract

Training Neural ODEs requires backpropagating through an ODE solve. The state-of-the-art backpropagation method is recursive checkpointing, which balances recomputation against memory cost. Here, we introduce a class of algebraically reversible ODE solvers that significantly improve upon both the time and memory cost of recursive checkpointing. The reversible solvers presented calculate exact gradients, are high-order, and are numerically stable, strictly improving on previous reversible architectures.
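To make "algebraically reversible" concrete, here is a minimal NumPy sketch. This summary does not state the solver's update rule, so the coupled two-state scheme below (a pair `y`, `z` with a coupling parameter `lam`) is an assumed form chosen for illustration, and the names `euler_increment`, `forward_step`, `backward_step`, and `decay` are hypothetical. Explicit Euler stands in for the single-step base solver, and a linear decay stands in for a trained vector field. The sketch only shows that the backward pass can rebuild earlier states from the final pair without storing the trajectory; it does not compute gradients.

```python
import numpy as np


def euler_increment(f, t, y, h):
    """Increment Psi_h(t, y) of explicit Euler, so that y_next = y + Psi_h(t, y)."""
    return h * f(t, y)


def forward_step(f, n, y, z, h, lam, psi=euler_increment):
    """Advance the coupled pair (y, z) from step n to step n + 1 (assumed scheme)."""
    y_next = lam * y + (1.0 - lam) * z + psi(f, n * h, z, h)
    z_next = z - psi(f, (n + 1) * h, y_next, -h)
    return y_next, z_next


def backward_step(f, n, y_next, z_next, h, lam, psi=euler_increment):
    """Recover (y, z) at step n from the pair at step n + 1 by solving the
    forward update algebraically; no stored intermediates are needed."""
    z = z_next + psi(f, (n + 1) * h, y_next, -h)
    y = (y_next - (1.0 - lam) * z - psi(f, n * h, z, h)) / lam
    return y, z


def decay(t, y):
    """Stand-in vector field dy/dt = -y (a trained network would go here)."""
    return -y


# Illustrative values only; they are not taken from the paper.
h, num_steps, lam = 0.01, 100, 0.99
y0 = np.array([1.0, 2.0])

# Forward solve: only the current pair (y, z) is kept in memory.
y, z = y0.copy(), y0.copy()
for n in range(num_steps):
    y, z = forward_step(decay, n, y, z, h, lam)
y_final = y.copy()

# Backward solve: rebuild earlier states from the final pair alone.
for n in reversed(range(num_steps)):
    y, z = backward_step(decay, n, y, z, h, lam)

reconstruction_error = np.max(np.abs(y - y0))
solution_error = np.max(np.abs(y_final - np.exp(-1.0) * y0))
print(reconstruction_error, solution_error)
```

In this sketch the backward step divides by `lam`, so floating-point error grows by a factor of roughly `1 / lam` per reversed step. That is one reason a value close to 1 is used here, and the reconstruction is exact only up to rounding.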

Paper Structure

This paper contains 30 sections, 6 theorems, 41 equations, 3 figures, 3 tables, and 1 algorithm.

Key Result

Theorem 2.1

For a fixed time-horizon $T>0$, we consider the Neural ODE in eq:neural-ode over $[0, T]$. Let $T=Nh$ where $N>0$ denotes the number of steps and $h>0$ is the step size. Let $\Psi$ be a $k$-th order ODE solver satisfying the Lipschitz condition (see def:lip-solver) and consider the reversible soluti

Figures (3)

  • Figure 1: Runtime complexity of the reversible backpropagation algorithm vs. recursive checkpointing with $c$ checkpoints.
  • Figure 2: Computation graph of the reversible method. (a) Forward solve. (b) Backward solve.
  • Figure 3: Convergence of the reversible solver (dashed) is inherited from the base solver (solid).

Theorems & Definitions (16)

  • Theorem 2.1
  • proof
  • Definition 2.2: Linear Stability
  • Theorem 2.3
  • proof
  • Example
  • Definition 1.1
  • Definition 1.2
  • Lemma 1.3
  • proof
  • ...and 6 more