
jinns: a JAX Library for Physics-Informed Neural Networks

Hugo Gangloff, Nicolas Jouvin

TL;DR

jinns is a JAX-native library for physics-informed neural networks (PINNs) that supports forward, inverse, and meta-modeling tasks by enforcing PDE residuals on collocation points through a global loss ${\mathcal L}(\nu,\theta)$. It separates the neural-network parameters $\nu$ from the equation parameters $\theta$, and introduces a DerivativeKeys mechanism for choosing which parameters each loss term is differentiated with respect to, all within a pure-JAX workflow built on Equinox and Optax. The framework provides standard PDE losses, vectorization via jax.vmap, and refined PINN architectures such as HyperPINN and SeparablePINN, and its data, losses, and architectures are extensible. Benchmarks against DeepXDE, Modulus, and PINA on the PINNacle suite show competitive performance, with particular strength on inverse problems, which supports its use for rapid prototyping and reproducible scientific computing within the JAX ecosystem.
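The ideas summarised above (a residual loss on collocation points, separate $\nu$ and $\theta$, and per-term control over which parameters receive gradients) can be sketched in plain JAX. This is an illustrative sketch only and does not use the jinns API: the toy ODE $du/dt + \theta u = 0$, the helper names (`init_mlp`, `u`, `residual_term`, `data_term`, `loss`), and the choice to block gradients with `jax.lax.stop_gradient` are all assumptions made for the example, not a description of how DerivativeKeys is implemented.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: tanh MLP parameters (nu) as a list of (W, b) pairs."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((W, jnp.zeros(n_out)))
    return params


def u(nu, t):
    """Network output u_nu(t) for a scalar input t."""
    h = jnp.reshape(t, (1,))
    for W, b in nu[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = nu[-1]
    return (h @ W + b)[0]


def residual_term(nu, theta, t_colloc):
    """Mean squared residual of the toy ODE du/dt + theta * u = 0."""
    def residual(t):
        du_dt = jax.grad(u, argnums=1)(nu, t)
        return du_dt + theta * u(nu, t)
    # Vectorise the pointwise residual over all collocation points.
    return jnp.mean(jax.vmap(residual)(t_colloc) ** 2)


def data_term(nu, t_obs, u_obs):
    """Mean squared misfit between the network and observations."""
    pred = jax.vmap(lambda t: u(nu, t))(t_obs)
    return jnp.mean((pred - u_obs) ** 2)


def loss(nu, theta, t_colloc, t_obs, u_obs):
    # Assumed design choice for this sketch: the residual term only updates
    # theta, so nu is frozen inside it; the data term only depends on nu.
    nu_frozen = jax.tree_util.tree_map(jax.lax.stop_gradient, nu)
    return residual_term(nu_frozen, theta, t_colloc) + data_term(nu, t_obs, u_obs)


key = jax.random.PRNGKey(0)
nu = init_mlp(key, [1, 8, 1])
theta = jnp.array(0.5)                 # initial guess for the equation parameter
t_colloc = jnp.linspace(0.0, 1.0, 16)  # collocation points
t_obs = jnp.array([0.25, 0.75])
u_obs = jnp.exp(-2.0 * t_obs)          # synthetic observations for theta = 2

# Gradients with respect to nu and theta are obtained separately.
g_nu, g_theta = jax.grad(loss, argnums=(0, 1))(nu, theta, t_colloc, t_obs, u_obs)
```

With this choice, `g_nu` carries only the data-term gradient while `g_theta` carries only the residual-term gradient; either pytree can then be passed to an optimiser such as one from Optax.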

Abstract

jinns is an open-source Python library for physics-informed neural networks, built to tackle both forward and inverse problems, as well as meta-model learning. Rooted in the JAX ecosystem, it provides a versatile framework for efficiently prototyping real-world problems, while remaining easy to extend to specific needs. The implementation builds on popular JAX libraries such as equinox and optax for model definition and optimisation, so the workflow will feel familiar to users of those libraries. Many models are available as baselines, and the documentation provides reference implementations of different use cases along with step-by-step tutorials for extending the library. The code is available on GitLab at https://gitlab.com/mia_jinns/jinns.

Paper Structure

This paper contains 22 sections, 3 equations, 1 figure, 3 tables.

Figures (1)

  • Figure 1: Typical user workflow for jinns users.