Modern, Efficient, and Differentiable Transport Equation Models using JAX: Applications to Population Balance Equations

Mohammed Alsubeihi, Arthur Jessop, Ben Moseley, Cláudio P. Fonte, Ashwin Kumar Rajagopalan

TL;DR

This work aims to prepare for planned Scientific Machine Learning (SciML) integration through a contemporary implementation of an existing PBE algorithm, one with computational efficiency and differentiability at the forefront, utilizing JAX, a cutting-edge library for accelerated computing.

Abstract

Population balance equation (PBE) models have the potential to automate many engineering processes with far-reaching implications. In the pharmaceutical sector, crystallization model-based design can contribute to shortening excessive drug development timelines. Even so, two major barriers, typical of most transport equations and not just PBEs, have limited this potential. First, the time taken to compute a solution to these models with representative accuracy is frequently limiting. Second, the model construction process is often tedious and time-consuming, owing to the reliance on human expertise to guess constituent models from empirical data. Hybrid models promise to overcome both barriers through tight integration of neural networks with physical PBE models. Towards eliminating experimental guesswork, hybrid models facilitate determining physical relationships from data, also known as 'discovering physics'. Here, we aim to prepare for planned Scientific Machine Learning (SciML) integration through a contemporary implementation of an existing PBE algorithm, one with computational efficiency and differentiability at the forefront. To accomplish this, we utilized JAX, a cutting-edge library for accelerated computing. We showcase the speed benefits of this modern take on PBE modelling by benchmarking our solver against others we prepared using older, more widespread software. Chief among these software tools is the ubiquitous NumPy, relative to which we show JAX achieves up to 300x acceleration in PBE simulations. Our solver is also fully differentiable, which we demonstrate is the only feasible option for integrating learnable data-driven models at scale. We show that exploiting differentiability can make optimizing larger models 40x faster than conventional approaches, which is the key to neural network integration for physics discovery in later work.

Paper Structure

This paper contains 26 sections, 17 equations, 8 figures, 4 tables.

Figures (8)

  • Figure 1: High-level flowchart of the computation within a PBE algorithm, with operations that would benefit from GPU execution (green dashed region) and/or JIT compilation (orange dotted region) indicated.
  • Figure 2: Comparison of the concentration (left) and total crystal volume (right) profiles obtained by each FVM solver to those obtained by the MOM.
  • Figure 3: Computation time benchmarks for different solvers with varying computational load by varying: a) the number of time steps, and b) the spatial domain size. Both extremes for each varied parameter are marked by vertical dashed lines. The CUDA and JAX (GPU) solvers have access to a GPU, while the remaining solvers are CPU-based.
  • Figure 4: Comparison of the average iteration time during parameter estimation with an increasing number of parameters for distinct differentiation algorithms: jax-AD, jax-ND, and slow-ND. The single marker indicates the time taken by the Bötschi 2019 solver to update 4 parameters over a single iteration. Two highlighted regions indicate the number of parameters normally optimized in traditional (number of parameters $\lessapprox 20$) and machine learning (number of parameters $\gtrapprox 20$) applications.
  • Figure 5: Two conceptual illustrations of potential hybrid models applied to the algorithm from Figure 1: a) Empirical hybrid model, where a neural network replaces the calculation of the advection rate. b) In-the-loop model, where a neural network estimates the discretization error of the solver at each iteration of the loop.
  • ...and 3 more figures