On Tuning Neural ODE for Stability, Consistency and Faster Convergence
Sheikh Waqas Akhtar
TL;DR
This work identifies the ODE-solver in Neural-ODEs as a bottleneck because of stability, convergence, and forward-evaluation issues related to the consistency, convergence, and stability (CCS) conditions. It introduces a first-order ODE-solver based on Nesterov's accelerated gradient (NAG) and tuned to the CCS conditions, links the solver design to linear multi-step methods, and shows performance comparable or superior to fixed-step solvers and ResNet on classification, density estimation, and time-series tasks. The results show faster convergence with competitive accuracy and highlight task-dependent benefits. The work also gives practical guidelines for verifying the CCS conditions and points to future directions such as adaptive Lipschitz estimation and implicit methods. Overall, the approach offers a principled, solver-centric way to accelerate Neural-ODEs while preserving stability and consistency during training.
Abstract
Neural-ODEs parameterize a differential equation with a continuous-depth neural network and solve it with a numerical ODE-integrator. These models have a constant memory cost, whereas the memory cost of models with a discrete sequence of hidden layers grows linearly with the number of layers. Beyond memory efficiency, Neural-ODEs can adapt their evaluation strategy to each input and offer the flexibility to trade numerical precision for faster training. Despite these benefits, they still have limitations. We identify the ODE-integrator (also called the ODE-solver) as the weakest link in the chain: it may suffer from stability, consistency, and convergence (CCS) issues, and may therefore converge slowly or not at all. We propose a first-order ODE-solver based on Nesterov's accelerated gradient (NAG) that is proven to be tuned with respect to the CCS conditions. We empirically demonstrate the efficacy of our approach: it trains faster while achieving better or comparable performance against Neural-ODEs that employ other fixed-step explicit ODE-solvers, as well as against discrete-depth models such as ResNet, on three tasks: supervised classification, density estimation, and time-series modelling.
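To make the link between NAG and a CCS-tuned ODE-solver concrete, here is a minimal Python sketch under stated assumptions. This section does not give the paper's exact update rule or coefficient conditions, so the sketch assumes a generic two-step scheme obtained by reading Nesterov's momentum update as an explicit integrator for dy/dt = f(t, y). The function name `nag_style_integrate` and the parameters `mu` and `consistent` are hypothetical. Treating the scheme like a linear multi-step method, its characteristic polynomial is rho(z) = z^2 - (1 + mu) z + mu, whose roots are 1 and mu. It is therefore zero-stable for |mu| < 1, and scaling the gradient-like step by (1 - mu) is what makes it consistent. Consistency plus zero-stability is the classical route to convergence (Dahlquist's equivalence theorem for linear multi-step methods).

```python
import math


def nag_style_integrate(f, y0, t0, t1, n_steps, mu=0.5, consistent=True):
    """Integrate dy/dt = f(t, y) on [t0, t1] with a hypothetical NAG-style two-step scheme.

    Assumed update (illustrative only, not the paper's solver):
        look    = y_k + mu * (y_k - y_{k-1})          # Nesterov look-ahead
        y_{k+1} = look + w * h * f(t_k, look)         # w = 1 - mu keeps it consistent
    """
    if abs(mu) >= 1:
        # The parasitic root of rho(z) = (z - 1)(z - mu) must lie inside the unit circle.
        raise ValueError("zero-stability requires |mu| < 1")
    h = (t1 - t0) / n_steps
    # w = 1 - mu balances rho'(1) = 1 - mu; w = 1 demonstrates an inconsistent scheme.
    w = (1.0 - mu) if consistent else 1.0
    # Bootstrap the two-step recursion with a single forward Euler step.
    y_prev, y = y0, y0 + h * f(t0, y0)
    for k in range(1, n_steps):
        t_k = t0 + k * h
        look = y + mu * (y - y_prev)
        y_prev, y = y, look + w * h * f(t_k, look)
    return y


if __name__ == "__main__":
    decay = lambda t, y: -y          # test problem dy/dt = -y, y(0) = 1
    exact = math.exp(-1.0)           # exact solution at t = 1
    for n in (100, 200, 400):
        good = abs(nag_style_integrate(decay, 1.0, 0.0, 1.0, n) - exact)
        bad = abs(nag_style_integrate(decay, 1.0, 0.0, 1.0, n, consistent=False) - exact)
        print(f"n={n:4d}  consistent error={good:.2e}  unscaled error={bad:.2e}")
```

With the (1 - mu) scaling, halving the step size roughly halves the error on dy/dt = -y, as expected of a first-order method. Dropping the scaling leaves the iteration stable, but it then approximates the rescaled ODE dy/dt = f / (1 - mu) instead of the original one, so its error does not shrink as the step is refined. This is the kind of consistency failure the CCS conditions are meant to rule out.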
