$\texttt{immrax}$: A Parallelizable and Differentiable Toolbox for Interval Analysis and Mixed Monotone Reachability in JAX
Akash Harapanahalli, Saber Jafarpour, Samuel Coogan
TL;DR
The paper presents immrax, a differentiable, parallelizable toolbox for interval analysis and mixed monotone reachability, implemented as JAX function transforms. It propagates bounds efficiently over input intervals $[\underline{x},\overline{x}]$ using inclusion functions $\mathsf{F}=[\underline{\mathsf{F}},\overline{\mathsf{F}}]$. It introduces composable transforms ($\texttt{natif}$, $\texttt{jacif}$, $\texttt{mjacif}$), including a novel mixed Jacobian-based inclusion function, and an embedding-system framework that propagates over-approximations of reachable sets through the bound dynamics $\dot{\underline{x}}$ and $\dot{\overline{x}}$. The framework is demonstrated on two case studies: GPU-accelerated reachability analysis of a nonlinear vehicle model controlled by a neural network, and robust optimal control of a pendulum driven by automatic differentiation. By combining interval analysis, differentiable programming, and parallel hardware, immrax makes certified reachability analysis of learning-enabled control systems more practical, differentiable, and scalable.
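The inclusion-function and embedding-system ideas summarized above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the immrax API: the interval helpers (`iadd`, `ineg`, `imul`), the inclusion function `F`, and the functions `embedding` and `reach` are hypothetical names, and the natural inclusion function is written by hand for one example vector field rather than generated by a transform such as $\texttt{natif}$. The Jacobian-based variants ($\texttt{jacif}$, $\texttt{mjacif}$) are not shown. Component $i$ of $\dot{\underline{x}}$ is obtained by evaluating $\underline{\mathsf{F}}$ on the face of the box where $\overline{x}_i$ is replaced by $\underline{x}_i$ (and symmetrically for $\dot{\overline{x}}$); the bounds are then integrated with forward Euler, which is a simplification rather than a rigorous integrator.

```python
import jax.numpy as jnp
from jax import jit, lax, vmap

# Intervals are (lo, hi) pairs of arrays. These helpers are hypothetical,
# written for illustration, and are not part of immrax.
def iadd(a, b):
    return a[0] + b[0], a[1] + b[1]

def ineg(a):
    return -a[1], -a[0]

def imul(a, b):
    # The extremes of a product of intervals lie at the endpoint products.
    p = jnp.stack([a[0] * b[0], a[0] * b[1], a[1] * b[0], a[1] * b[1]])
    return p.min(axis=0), p.max(axis=0)

def f(x):
    # Example vector field (Duffing-type oscillator): x1' = x2, x2' = -x1 - x1^3.
    return jnp.stack([x[1], -x[0] - x[0] * x[0] * x[0]])

def F(x_lo, x_hi):
    # Natural inclusion function of f: each operation is swapped for its interval version.
    x1, x2 = (x_lo[0], x_hi[0]), (x_lo[1], x_hi[1])
    rhs2 = ineg(iadd(x1, imul(imul(x1, x1), x1)))  # -(x1 + x1^3)
    return jnp.stack([x2[0], rhs2[0]]), jnp.stack([x2[1], rhs2[1]])

def embedding(x_lo, x_hi):
    # Embedding system: component i of the lower (upper) bound dynamics is F's
    # lower (upper) output on the lower (upper) face of the box in coordinate i.
    def face(i):
        lo_i = F(x_lo, x_hi.at[i].set(x_lo[i]))[0][i]
        hi_i = F(x_lo.at[i].set(x_hi[i]), x_hi)[1][i]
        return lo_i, hi_i
    # vmap evaluates all faces in one vectorized call.
    return vmap(face)(jnp.arange(x_lo.shape[0]))

DT, STEPS = 0.01, 100  # forward Euler step size and number of steps (illustrative)

@jit
def reach(x_lo, x_hi):
    # Integrate the bound dynamics with forward Euler.
    def step(_, box):
        d_lo, d_hi = embedding(box[0], box[1])
        return box[0] + DT * d_lo, box[1] + DT * d_hi
    return lax.fori_loop(0, STEPS, step, (x_lo, x_hi))
```

For example, `reach(jnp.array([0.9, -0.1]), jnp.array([1.1, 0.1]))` returns the lower and upper corners of a box after one time unit. For this particular vector field (where $f_i$ does not depend on $x_i$), the Euler trajectory started anywhere in the initial box stays inside the computed bounds, up to floating-point rounding.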
Abstract
We present an implementation of interval analysis and mixed monotone interval reachability analysis as function transforms in Python, fully composable with the computational framework JAX. The resulting toolbox inherits several key features from JAX, including computational efficiency through just-in-time (JIT) compilation, GPU acceleration for fast parallel computation, and automatic differentiation. We demonstrate the toolbox's performance on several case studies, including a reachability problem for a vehicle model controlled by a neural network and a robust closed-loop optimal control problem for a swinging pendulum.
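The JIT, parallelism, and differentiability features listed above come from composing JAX transforms (`jax.jit`, `jax.vmap`, `jax.grad`) with bound propagation. The sketch below is hypothetical and much simpler than the paper's pendulum study: it assumes a pendulum linearized about the upright position with unit parameters, linear state feedback $u = -Kx$ with gains `gains`, an additive disturbance in $[-\epsilon, \epsilon]^2$, forward Euler integration of the embedding system, and the total side length of the final box as the objective. The names `embedding`, `reach`, `loss`, `grad_loss`, and `batched_reach` are illustrative and not part of immrax.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup (not from the paper): pendulum linearized about upright with
# unit parameters, x = (angle, angular velocity), feedback u = -K x, and an
# additive disturbance w in [-EPS, EPS]^2.
A = jnp.array([[0.0, 1.0], [1.0, 0.0]])
B = jnp.array([0.0, 1.0])
DT, STEPS, EPS = 0.01, 200, 0.05

def embedding(x_lo, x_hi, M, w_lo, w_hi):
    # Embedding system for the linear field M x + w: keep the diagonal of M exact
    # and split the off-diagonal part into nonnegative and nonpositive entries.
    D = jnp.diag(jnp.diag(M))
    off_pos = jnp.maximum(M - D, 0.0)
    off_neg = jnp.minimum(M - D, 0.0)
    d_lo = D @ x_lo + off_pos @ x_lo + off_neg @ x_hi + w_lo
    d_hi = D @ x_hi + off_pos @ x_hi + off_neg @ x_lo + w_hi
    return d_lo, d_hi

def reach(gains, x_lo, x_hi):
    M = A - jnp.outer(B, gains)  # closed-loop matrix A - B K
    w_lo, w_hi = -EPS * jnp.ones(2), EPS * jnp.ones(2)
    def step(_, box):
        d_lo, d_hi = embedding(box[0], box[1], M, w_lo, w_hi)
        return box[0] + DT * d_lo, box[1] + DT * d_hi
    # Static trip count, so reverse-mode differentiation through the loop works.
    return jax.lax.fori_loop(0, STEPS, step, (x_lo, x_hi))

def loss(gains, x_lo, x_hi):
    # Sum of the final box's side lengths: a crude robustness objective.
    lo, hi = reach(gains, x_lo, x_hi)
    return jnp.sum(hi - lo)

# JIT-compiled gradient of the bound with respect to the feedback gains.
grad_loss = jax.jit(jax.grad(loss))
# One compiled call that propagates a batch of initial boxes in parallel.
batched_reach = jax.jit(jax.vmap(reach, in_axes=(None, 0, 0)))
```

The same code runs on CPU or GPU depending on the installed JAX backend. A gradient step on the gains could look like `gains - 1e-3 * grad_loss(gains, x_lo, x_hi)`, and `batched_reach` takes initial boxes stacked along a leading axis.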
