
Training Neural ODEs Using Fully Discretized Simultaneous Optimization

Mariia Shapovalova, Calvin Tsay

TL;DR

This work tackles the high computational cost of training Neural ODEs by using a collocation-based, fully discretized simultaneous optimization approach. By formulating the training as a single nonlinear program with collocation constraints and solving for both collocation states and neural network parameters with IPOPT (via Pyomo), the method achieves rapid convergence. The authors further extend the framework with ADMM to enable batching across data, and demonstrate improved performance on the Van der Pol oscillator compared with traditional sequential training pipelines. The results indicate the approach can yield faster, more data-efficient training and support scalable multi-batch training with potential for compact model representations.
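The sketch below makes "collocation constraints" concrete. It writes the discretized ODE as algebraic residuals that a solver such as IPOPT would drive to zero while also adjusting the network weights. It is a hedged illustration, not the authors' Pyomo model, and it rests on several assumptions:

- It uses the simplest collocation scheme, the two-point trapezoidal rule (Lobatto IIIA).
- The state is scalar.
- The right-hand side is a hypothetical one-hidden-layer tanh network, `mlp_rhs`.
- The helper names `trapezoid_residuals` and `data_mse` are illustrative.

The paper's collocation scheme, state dimension, and architecture may all differ.

```python
import math
from typing import List, Sequence


def mlp_rhs(x: float, w1: Sequence[float], b1: Sequence[float],
            w2: Sequence[float], b2: float) -> float:
    """Hypothetical one-hidden-layer tanh network f_theta(x) for a scalar state."""
    hidden = [math.tanh(w * x + b) for w, b in zip(w1, b1)]
    return sum(v * a for v, a in zip(w2, hidden)) + b2


def trapezoid_residuals(states: Sequence[float], h: float, params: dict) -> List[float]:
    """Collocation residuals r_k = x_{k+1} - x_k - h/2 * (f(x_k) + f(x_{k+1})).

    In a simultaneous approach, both `states` and `params` are decision
    variables; the solver enforces r_k = 0 as equality constraints instead of
    integrating the ODE in an inner loop at every training epoch.
    """
    f = [mlp_rhs(x, **params) for x in states]
    return [states[k + 1] - states[k] - 0.5 * h * (f[k] + f[k + 1])
            for k in range(len(states) - 1)]


def data_mse(states: Sequence[float], observed: Sequence[float]) -> float:
    """Data-fit objective evaluated directly on the discretized states."""
    return sum((x - y) ** 2 for x, y in zip(states, observed)) / len(observed)


if __name__ == "__main__":
    # With w2 = 0 the network outputs the constant b2, so x(t) = x0 + b2 * t
    # satisfies every collocation constraint (up to floating-point rounding).
    params = {"w1": [1.0, -0.5], "b1": [0.0, 0.1], "w2": [0.0, 0.0], "b2": 0.3}
    h = 0.1
    states = [1.0 + 0.3 * h * k for k in range(5)]
    print(max(abs(r) for r in trapezoid_residuals(states, h, params)))
```

In a full NLP, `data_mse` would be the objective and each entry of `trapezoid_residuals` an equality constraint. A solver like IPOPT then sees states and weights together rather than differentiating through an ODE solve.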

Abstract

Neural Ordinary Differential Equations (Neural ODEs) represent continuous-time dynamics with neural networks, offering advances in modeling and control tasks. However, training Neural ODEs requires solving differential equations at each epoch, which leads to high computational costs. This work investigates simultaneous optimization methods as a faster training alternative. In particular, we employ a collocation-based, fully discretized formulation and use IPOPT, a solver for large-scale nonlinear optimization, to simultaneously optimize collocation coefficients and neural network parameters. Using the Van der Pol oscillator as a case study, we demonstrate faster convergence than traditional training methods. Furthermore, we introduce a decomposition framework that uses the Alternating Direction Method of Multipliers (ADMM) to effectively coordinate sub-models across data batches. Our results show significant potential for (collocation-based) simultaneous Neural ODE training pipelines.
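Consensus ADMM, the pattern of coordinating sub-models across batches, can be shown on a deliberately tiny problem. This is an assumption-laden toy, not the authors' framework:

- Each data batch owns a local copy of a single scalar parameter, fit by least squares, so every local subproblem has a closed form.
- The penalty `rho` and the iteration count are arbitrary choices.
- The function name `consensus_admm` is hypothetical.

In the paper's setting, each local subproblem would presumably be a collocation-based NLP rather than a scalar least-squares fit.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # (inputs a_j, targets b_j)


def consensus_admm(batches: List[Batch], rho: float = 10.0,
                   iters: int = 100) -> Tuple[float, List[float]]:
    """Consensus ADMM for min_theta sum_i 0.5 * sum_j (a_ij * theta - b_ij)^2.

    Each batch i keeps a local copy theta_i; the consensus variable z and the
    scaled dual variables u_i push the local copies toward agreement.
    """
    n = len(batches)
    thetas = [0.0] * n
    u = [0.0] * n
    z = 0.0
    # Per-batch sufficient statistics for the closed-form local update.
    saa = [sum(a * a for a in xs) for xs, _ in batches]
    sab = [sum(a * b for a, b in zip(xs, ys)) for xs, ys in batches]
    for _ in range(iters):
        # Local step: argmin_theta f_i(theta) + rho/2 * (theta - z + u_i)^2.
        thetas = [(sab[i] + rho * (z - u[i])) / (saa[i] + rho) for i in range(n)]
        # Consensus step: average of local copies plus scaled duals.
        z = sum(t + ui for t, ui in zip(thetas, u)) / n
        # Dual step: accumulate each batch's disagreement with the consensus.
        u = [ui + t - z for ui, t in zip(u, thetas)]
    return z, thetas


if __name__ == "__main__":
    batch1 = ([1.0, 2.0, 3.0], [2.1, 3.9, 6.2])
    batch2 = ([4.0, 5.0], [7.8, 10.1])
    z, local = consensus_admm([batch1, batch2])
    # Global least-squares solution over the pooled data, for comparison.
    xs, ys = batch1[0] + batch2[0], batch1[1] + batch2[1]
    theta_star = sum(a * b for a, b in zip(xs, ys)) / sum(a * a for a in xs)
    print(z, theta_star, local)
```

The local steps are independent and could run in parallel, one per batch. Only the averaging step needs communication between batches.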

Paper Structure

This paper contains 21 sections, 23 equations, 5 figures, 1 table.

Figures (5)

  • Figure 1: Predictions of a model trained with collocation-based method (Pyomo) for the Van der Pol Oscillator. 200 training points and 200 testing points.
  • Figure 2: Training MSE for three training frameworks. Note that the MSE values at intermediate training times are obtained by interrupting IPOPT’s runtime and may appear unstable.
  • Figure 3: Testing MSE for three training frameworks.
  • Figure 4: Training MSE of JAX model with pre-training.
  • Figure 5: Performance of ADMM (150 + 150 training points) vs. a single model (300 training points).
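The sketch below generates Van der Pol training data. It integrates the oscillator with a classical fourth-order Runge-Kutta scheme and splits the trajectory into 200 training and 200 testing points, matching the counts quoted for Figure 1. The source does not say how the paper generated its data, so several choices here are assumptions:

- the damping parameter `mu = 1.0`
- the initial state `(2.0, 0.0)`
- the step size `dt = 0.05`
- the chronological train/test split
- the helper names (`van_der_pol`, `rk4_step`, `make_dataset`)

```python
from typing import List, Tuple

State = Tuple[float, float]


def van_der_pol(state: State, mu: float = 1.0) -> State:
    """Van der Pol dynamics: x' = y, y' = mu * (1 - x^2) * y - x."""
    x, y = state
    return (y, mu * (1.0 - x * x) * y - x)


def rk4_step(state: State, dt: float, mu: float) -> State:
    """One classical fourth-order Runge-Kutta step."""
    def shift(s: State, k: State, scale: float) -> State:
        return (s[0] + scale * k[0], s[1] + scale * k[1])

    k1 = van_der_pol(state, mu)
    k2 = van_der_pol(shift(state, k1, dt / 2), mu)
    k3 = van_der_pol(shift(state, k2, dt / 2), mu)
    k4 = van_der_pol(shift(state, k3, dt), mu)
    return (
        state[0] + dt / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0]),
        state[1] + dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1]),
    )


def make_dataset(n_train: int = 200, n_test: int = 200, dt: float = 0.05,
                 x0: State = (2.0, 0.0), mu: float = 1.0):
    """Simulate n_train + n_test samples and split them chronologically (assumed split)."""
    times: List[float] = [0.0]
    states: List[State] = [x0]
    for k in range(1, n_train + n_test):
        states.append(rk4_step(states[-1], dt, mu))
        times.append(k * dt)
    train = (times[:n_train], states[:n_train])
    test = (times[n_train:], states[n_train:])
    return train, test


if __name__ == "__main__":
    (t_train, x_train), (t_test, x_test) = make_dataset()
    print(len(x_train), len(x_test))
```

A chronological split tests extrapolation beyond the training window. A random split would instead test interpolation, and the paper may use either.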