JAX-LaB: A High-Performance, Differentiable, Lattice Boltzmann Library for Modeling Multiphase Fluid Dynamics in Geosciences and Engineering

Piyush Pradhan, Pierre Gentine, Shaina Kelly

TL;DR

JAX-LaB delivers a differentiable, high-performance Lattice Boltzmann method (LBM) framework for multiphase and multicomponent flows in geosciences, integrating the Shan-Chen pseudopotential method with arbitrary EOSs and a pressure-tensor modification that decouples surface tension from density ratio, achieving density ratios $> 10^7$ while suppressing spurious currents. Built on Python and JAX, it supports single- and multi-GPU execution, scalable distributed runs, and seamless ML integration, enabling forward and inverse modeling in porous media and hydrology. The paper validates thermodynamic consistency, Laplace’s law, capillary dynamics, and porous-media flows (permeability, drainage in sandstone, and sphere-pack characteristic curves), and reports single-GPU performance along with weak/strong multi-GPU scaling benchmarks. Open-source under the Apache license, JAX-LaB provides a modular, extensible platform for pore-scale simulations, differentiable modeling, and ML-assisted design in geoscience and engineering contexts.

Abstract

We introduce JAX-LaB, a differentiable, Python-based Lattice Boltzmann simulation library designed for modeling multiphase and multiphysics fluid dynamics problems in hydrologic, geologic, and engineered porous media settings. The library is designed as an extension to XLB, and it is built on the JAX framework. The library delivers a performant, hardware-agnostic implementation that seamlessly integrates with machine learning libraries and scales efficiently across CPUs, multi-GPU setups, and distributed environments. Multiphase interactions are modeled using the Shan-Chen pseudopotential method, coupled with an equation of state (EOS) to reproduce densities consistent with Maxwell's construction, enabling accurate simulation of flows with density ratios $> 10^7$ while maintaining low spurious currents. Fluid wetting is achieved using the "improved" virtual density scheme, which enables precise control of contact angle on flat and curved surfaces, while eliminating non-physical films seen in the Shan-Chen virtual density scheme. This scheme integrates directly into the interaction force calculations, removing the need to handle fluid-fluid and fluid-solid forces separately. We validate the library's accuracy and performance through comprehensive analytical benchmarks, including Laplace's law, capillary rise in parallel plates, and multi-component cocurrent flow in a channel. We then use the code for several applications involving multicomponent and multiphase flows, including permeability estimation, injection of supercritical $\mathrm{CO_2}$ in a water-saturated Fontainebleau sandstone, and obtaining the characteristic curves for a sphere pack geometry. Finally, the single-GPU performance and multi-GPU scaling of the code are evaluated on both single-node and distributed systems. The library is open-source under the Apache license and available at https://github.com/piyush-ppradhan/JAX-LaB.
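To make the coupling between the Shan-Chen pseudopotential and an equation of state concrete, the following is a minimal JAX sketch, not JAX-LaB's actual API. It assumes a D2Q9 lattice with periodic boundaries, the Carnahan-Starling EOS with the commonly used illustrative parameters $a = 1$, $b = 4$, $R = 1$ (for which $T_c \approx 0.0943$), and $G = -1$. The function names `cs_pressure`, `pseudopotential`, and `shan_chen_force` are hypothetical. The pressure-tensor modification, the forcing scheme, and the wetting treatment described in the paper are not reproduced here.

```python
import jax.numpy as jnp

# D2Q9 lattice velocities and their standard weights.
E = ((0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1))
W = (4 / 9,) + (1 / 9,) * 4 + (1 / 36,) * 4
CS2 = 1.0 / 3.0  # lattice speed of sound squared


def cs_pressure(rho, T, a=1.0, b=4.0, R=1.0):
    """Carnahan-Starling EOS in lattice units (a, b, R are assumed values)."""
    eta = b * rho / 4.0
    return rho * R * T * (1 + eta + eta**2 - eta**3) / (1 - eta) ** 3 - a * rho**2


def pseudopotential(rho, T, G=-1.0):
    """Invert p_EOS = rho*cs^2 + (G*cs^2/2)*psi^2 for psi.

    With G < 0 the radicand is positive as long as p_EOS < rho*cs^2.
    """
    return jnp.sqrt(2.0 * (cs_pressure(rho, T) - CS2 * rho) / (G * CS2))


def shan_chen_force(rho, T, G=-1.0):
    """F(x) = -G * psi(x) * sum_i w_i * psi(x + e_i) * e_i on a periodic grid."""
    psi = pseudopotential(rho, T, G)
    sx = jnp.zeros_like(psi)
    sy = jnp.zeros_like(psi)
    for (ex, ey), w in zip(E, W):
        # Rolling by -e_i places psi(x + e_i) at index x.
        psi_nb = jnp.roll(psi, shift=(-ex, -ey), axis=(0, 1))
        sx = sx + w * ex * psi_nb
        sy = sy + w * ey * psi_nb
    return -G * psi * sx, -G * psi * sy


# Usage: a smooth density wave along x at reduced temperature T_r = 0.8.
T = 0.8 * 0.0943
x = jnp.arange(16.0)
rho = 0.2 + 0.1 * jnp.sin(2 * jnp.pi * x / 16)[:, None] * jnp.ones((1, 16))
fx, fy = shan_chen_force(rho, T)
```

Because every step is a plain `jax.numpy` operation, the force can be wrapped in `jax.jit` or differentiated with `jax.grad` with respect to `T` or the EOS parameters, which is the property a differentiable LBM library relies on.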

Paper Structure

This paper contains 25 sections, 26 equations, 17 figures, 3 tables.

Figures (17)

  • Figure 1: Illustration of computational sharding strategy and data structure employed to store JAX arrays and distribute them equally across multiple devices. Arrays are sliced along the x-axis and then divided across multiple devices. These sharded arrays are stored together in a list to form a pytree, with each JAX array storing values for one component.
  • Figure 2: Implementation of the Shan-Chen force using pytrees in JAX-LaB. The colored squares indicate the different leaves of the pytree, each of which can be either a floating-point value or a JAX array.
  • Figure 3: Schematic diagram for a droplet suspended in vapor.
  • Figure 4: Comparison of coexistence densities versus reduced temperature derived from Lattice Boltzmann simulations with predictions from the Maxwell construction for various equations of state (EOS): (a) Carnahan-Starling (CS), (b) Peng-Robinson (PR), (c) Redlich-Kwong (RK), (d) Redlich-Kwong-Soave (RKS), and (e) van der Waals (VdW). (f) Peak spurious current magnitudes across various temperatures for different equations of state. All values are in lattice units.
  • Figure 5: The relationship between pressure differential and droplet curvature (1/$R$) plotted for various values of $\kappa$, shown separately for (a) Carnahan-Starling (CS), (b) Peng-Robinson (PR), (c) Redlich-Kwong (RK), (d) Redlich-Kwong-Soave (RKS) and (e) van der Waals (VdW) equation of state at $\tau_v = 1.0$ and $T_r = 0.8$. All values are in lattice units.
  • ...and 12 more figures
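
The data layout described in the captions of Figures 1 and 2 (one array per component, collected in a list that forms a pytree, with each array sliced along the x-axis across devices) can be sketched with standard JAX sharding utilities. This is an illustrative sketch under stated assumptions, not JAX-LaB's code: the array shape `(nx, ny, q)`, the two-component setup, and the names `f`, `tau`, and `x_sharding` are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; its single axis is named "x".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
# Split array axis 0 (the x-direction) across devices, replicate the rest.
x_sharding = NamedSharding(mesh, P("x", None, None))

n_components = 2  # hypothetical two-component system
nx, ny, q = 8 * len(devices), 16, 9  # nx must be divisible by the device count

# Pytree: a list holding one (nx, ny, q) population array per component.
f = [jnp.full((nx, ny, q), 1.0 / q) for _ in range(n_components)]
f = jax.tree_util.tree_map(lambda a: jax.device_put(a, x_sharding), f)

# Per-component density: sum the populations over the last axis, leaf by leaf.
rho = jax.tree_util.tree_map(lambda fk: fk.sum(axis=-1), f)

# Pytree leaves may also be plain floats, e.g. one assumed relaxation time
# per component, combined with the arrays in a single tree_map call.
tau = [1.0, 0.8]
scaled = jax.tree_util.tree_map(lambda fk, tk: fk / tk, f, tau)
```

On a machine with a single device the mesh has one entry and each array stays whole; with several devices, each one holds a contiguous slab of `nx // len(devices)` lattice columns per component.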