An Expert's Guide to Training Physics-informed Neural Networks

Sifan Wang, Shyam Sankaran, Hanwen Wang, Paris Perdikaris

TL;DR

This work addresses the training fragility of PINNs in solving forward and inverse PDE problems. It presents a cohesive pipeline that combines non-dimensionalization, Fourier feature embeddings, and random weight factorization with training strategies such as causal weighting, adaptive loss balancing, and curriculum training. Through extensive, fully reproducible ablations on diverse benchmarks, it demonstrates state-of-the-art accuracy and improved training reliability. A high-performance JAX library accompanies the method to enable replication and adaptation to real-world scenarios.
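
The non-dimensionalization step mentioned above can be illustrated with a minimal sketch. It uses a 1D heat equation $u_t = \kappa u_{xx}$ as a stand-in example (an assumption, not one of the paper's benchmarks) and hypothetical characteristic scales `L`, `T`, `U`. Substituting $x = L x^*$, $t = T t^*$, $u = U u^*$ turns the equation into $u^*_{t^*} = (\kappa T / L^2)\, u^*_{x^* x^*}$. When the scales match the problem, the network therefore sees inputs and outputs of order one.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class Scales:
    """Hypothetical characteristic scales used to non-dimensionalize a PDE."""
    L: float  # characteristic length
    T: float  # characteristic time
    U: float  # characteristic magnitude of the solution u


def to_dimensionless(x, t, u, s: Scales):
    # x* = x / L, t* = t / T, u* = u / U
    return x / s.L, t / s.T, u / s.U


def to_physical(x_star, t_star, u_star, s: Scales):
    # Invert the rescaling to report predictions in physical units
    return x_star * s.L, t_star * s.T, u_star * s.U


def heat_coefficient(kappa, s: Scales):
    # u_t = kappa * u_xx  becomes  u*_{t*} = (kappa * T / L**2) * u*_{x* x*}
    return kappa * s.T / s.L**2


s = Scales(L=2.0, T=10.0, U=5.0)
x = jnp.linspace(0.0, 2.0, 5)
t = jnp.linspace(0.0, 10.0, 5)
u = 5.0 * jnp.sin(jnp.pi * x / 2.0)
x_star, t_star, u_star = to_dimensionless(x, t, u, s)
print(float(x_star.max()), float(t_star.max()), heat_coefficient(0.1, s))
```

With these illustrative scales, both coordinates fall in $[0, 1]$ and the PDE coefficient becomes $0.25$. The PINN is trained on the starred variables, and `to_physical` maps its predictions back.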

Abstract

Physics-informed neural networks (PINNs) have been popularized as a deep learning framework that can seamlessly synthesize observational data and partial differential equation (PDE) constraints. However, their practical effectiveness can be hampered by training pathologies and, oftentimes, by poor choices made by users who lack deep learning expertise. In this paper we present a series of best practices that can significantly improve the training efficiency and overall accuracy of PINNs. We also put forth a series of challenging benchmark problems that highlight some of the most prominent difficulties in training PINNs, and present comprehensive and fully reproducible ablation studies that demonstrate how different architecture choices and training strategies affect the test accuracy of the resulting models. We show that the methods and guiding principles put forth in this study lead to state-of-the-art results and provide strong baselines that future studies should use for comparison purposes. To this end, we also release a highly optimized library in JAX that can be used to reproduce all results reported in this paper, enable future research studies, and facilitate easy adaptation to new use-case scenarios.
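
To make "synthesize observational data and PDE constraints" concrete, here is a minimal JAX sketch of a PINN composite loss. A small MLP $u_\theta(t, x)$ is penalized both for mismatching initial-condition data and for violating the PDE at collocation points, with derivatives obtained by automatic differentiation. The sketch makes several assumptions. The PDE is taken to be the 1D Allen-Cahn form $u_t - 10^{-4} u_{xx} + 5u^3 - 5u = 0$ with $u(0, x) = x^2 \cos(\pi x)$, a widely used PINN benchmark. Boundary conditions are omitted. The network size, initialization, and loss weights `w_ic`, `w_r` are illustrative and are not the configuration of the accompanying library.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    # Glorot-normal style initialization; sizes e.g. [2, 32, 32, 1]
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_out, n_in)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((w, jnp.zeros(n_out)))
    return params


def u_net(params, t, x):
    # Scalar network output u_theta(t, x) for scalar inputs t, x
    h = jnp.stack([t, x])
    for w, b in params[:-1]:
        h = jnp.tanh(w @ h + b)
    w, b = params[-1]
    return (w @ h + b)[0]


def residual(params, t, x):
    # Allen-Cahn residual (assumed form): u_t - 1e-4 u_xx + 5 u^3 - 5 u
    u = u_net(params, t, x)
    u_t = jax.grad(u_net, argnums=1)(params, t, x)
    u_xx = jax.grad(jax.grad(u_net, argnums=2), argnums=2)(params, t, x)
    return u_t - 1e-4 * u_xx + 5.0 * u**3 - 5.0 * u


def pinn_loss(params, t_r, x_r, x_ic, u_ic, w_ic=1.0, w_r=1.0):
    # Data term: fit u(0, x) to initial-condition samples
    t0 = jnp.zeros_like(x_ic)
    u0 = jax.vmap(u_net, in_axes=(None, 0, 0))(params, t0, x_ic)
    loss_ic = jnp.mean((u0 - u_ic) ** 2)
    # Physics term: mean squared PDE residual at collocation points
    r = jax.vmap(residual, in_axes=(None, 0, 0))(params, t_r, x_r)
    loss_r = jnp.mean(r**2)
    return w_ic * loss_ic + w_r * loss_r


key_p, key_t, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = init_mlp(key_p, [2, 32, 32, 1])
t_r = jax.random.uniform(key_t, (128,))
x_r = jax.random.uniform(key_x, (128,), minval=-1.0, maxval=1.0)
x_ic = jnp.linspace(-1.0, 1.0, 64)
u_ic = x_ic**2 * jnp.cos(jnp.pi * x_ic)  # assumed initial condition
loss, grads = jax.value_and_grad(pinn_loss)(params, t_r, x_r, x_ic, u_ic)
```

The fixed weights `w_ic` and `w_r` are exactly what the adaptive loss-balancing strategies in the paper replace with weights that are updated during training.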

Paper Structure

This paper contains 28 sections, 2 theorems, 67 equations, 25 figures, 13 tables, and 2 algorithms.

Key Result

Theorem B.1

Suppose that $\mathcal{L}(\mathbf{\theta})$ is the loss function associated with a neural network defined in Eqs. \ref{eq: mlp_1} and \ref{eq: mlp_2}. For a given $\mathbf{\theta}$, we define $U_{\mathbf{\theta}}$ as the set containing all possible weight factorizations. Then for any $\mathbf{\theta}, \mathbf{\theta}'$, we have
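
Theorem B.1 is stated in terms of weight factorizations. The following minimal sketch assumes the random weight factorization (RWF) form $W = \operatorname{diag}(\exp(\mathbf{s}))\, V$, with one scale factor per output neuron drawn as $\mathbf{s} \sim \mathcal{N}(\mu, \sigma^2)$. Under this form, a dense layer can be re-parameterized so that it reproduces a conventionally initialized $W$ exactly, while gradient descent then acts on $(\mathbf{s}, V)$ instead of $W$. The function names are hypothetical, and the values $\mu = 1$, $\sigma = 0.1$ are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rwf_init(key, n_in, n_out, mu=1.0, sigma=0.1):
    k_w, k_s = jax.random.split(key)
    # Conventional Glorot-normal style weight matrix W of shape (n_out, n_in)
    w = jax.random.normal(k_w, (n_out, n_in)) * jnp.sqrt(2.0 / (n_in + n_out))
    # One scale factor per output neuron: s ~ N(mu, sigma^2)
    s = mu + sigma * jax.random.normal(k_s, (n_out,))
    # Choose V = diag(exp(-s)) W so that diag(exp(s)) V == W at initialization
    v = jnp.exp(-s)[:, None] * w
    return w, s, v


def rwf_kernel(s, v):
    # Reassemble W = diag(exp(s)) V
    return jnp.exp(s)[:, None] * v


def rwf_dense(s, v, b, x):
    # Dense layer y = W x + b using the factorized kernel
    return rwf_kernel(s, v) @ x + b


w, s, v = rwf_init(jax.random.PRNGKey(0), n_in=3, n_out=4)
print(jnp.allclose(rwf_kernel(s, v), w, atol=1e-6))
# Non-uniqueness: (s + c, V * exp(-c)) yields the same W for any constant c
print(jnp.allclose(rwf_kernel(s + 0.5, v * jnp.exp(-0.5)), w, atol=1e-6))
```

The last line shows that the factorization is not unique. Shifting $\mathbf{s}$ by any constant $c$ and rescaling $V$ by $e^{-c}$ leaves $W$ unchanged, so a set such as $U_{\mathbf{\theta}}$ generally contains many elements.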

Figures (25)

  • Figure 1: Illustration of the proposed training pipeline. The procedure begins with the non-dimensionalization of the PDE system, ensuring that input and output variables are in a reasonable range. Subsequently, an appropriate network architecture is constructed to represent the unknown PDE solution. The use of Fourier feature embeddings and random weight factorization is highly recommended for mitigating spectral bias and accelerating convergence. The training phase of the PINN model integrates various advanced algorithms, including self-adaptive loss balancing, causal training, and curriculum training.
  • Figure 2: Efficiency of weak scaling using the Navier-Stokes flow (Section \ref{sec: ns_tori}) as a benchmark. We employ a neural network with the hyper-parameters shown in Table \ref{tab: ns_tori_config} and measure the execution time for 10,000 iterations, keeping the batch size fixed at 40,960 per GPU.
  • Figure 3: Allen-Cahn equation: Analysis of training a plain PINN model for $10,000$ iterations. Top left: Histograms of back-propagated gradients of the PDE residual loss and the initial condition loss at the last iteration. Top right: Temporal PDE residual loss at the last iteration. Bottom: NTK eigenvalues of $K_{ic}$ and $K_r$ at the last iteration.
  • Figure 4: Allen-Cahn equation: Convergence of the relative $L^2$ error in the ablation study with different components disabled. Plain: Conventional PINN formulation. Default: PINN model trained using Algorithm \ref{alg: pipline}. No RWF: PINN model trained using Algorithm \ref{alg: pipline} without random weight factorization. No Grad Norm: PINN model trained using Algorithm \ref{alg: pipline} without the grad norm weighting scheme. No Fourier feature: PINN model trained using Algorithm \ref{alg: pipline} without random Fourier feature embeddings. No Causal: PINN model trained using Algorithm \ref{alg: pipline} without causal weighting.
  • Figure 5: Allen-Cahn equation: Comparison of the best prediction against the reference solution. The resulting relative $L^2$ error is $5.37 \times 10^{-5}$. The hyper-parameter configuration can be found in Table \ref{tab: ac_config}.
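
The Fourier feature embedding named in Figure 1 can be sketched as follows. The sketch assumes the common random Fourier feature form $\gamma(\mathbf{x}) = [\cos(B\mathbf{x}), \sin(B\mathbf{x})]$, with the entries of $B$ drawn i.i.d. from $\mathcal{N}(0, \sigma^2)$ and applied to the inputs before the first dense layer. The scale $\sigma$ and the number of features are tunable hyperparameters; the values below, and the function names, are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def init_fourier_matrix(key, in_dim, num_features, sigma):
    # B has i.i.d. N(0, sigma^2) entries; larger sigma gives higher frequencies
    return sigma * jax.random.normal(key, (num_features, in_dim))


def fourier_embed(B, x):
    # x: (batch, in_dim) -> (batch, 2 * num_features)
    proj = x @ B.T
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)


B = init_fourier_matrix(jax.random.PRNGKey(0), in_dim=2, num_features=64, sigma=1.0)
x = jax.random.uniform(jax.random.PRNGKey(1), (10, 2))  # e.g. (t, x) pairs
z = fourier_embed(B, x)
print(z.shape)  # (10, 128)
```

Because $\cos^2 + \sin^2 = 1$, every embedded point has squared norm equal to `num_features`. The embedding therefore keeps the network input at a fixed scale, while $\sigma$ controls only which frequencies the network can readily represent.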

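The two weighting schemes ablated in Figure 4 can also be sketched in a few lines, under stated assumptions.

- Grad norm weighting: we assume the form $\hat{\lambda}_i = \sum_j \|\nabla_{\theta} \mathcal{L}_j\| / \|\nabla_{\theta} \mathcal{L}_i\|$, smoothed by a moving average $\lambda_i \leftarrow \alpha \lambda_i + (1 - \alpha)\hat{\lambda}_i$.
- Causal weighting: we assume the temporal domain is split into $M$ chunks with residual losses $\mathcal{L}_r(t_i)$ and weights $w_i = \exp(-\epsilon \sum_{k=1}^{i-1} \mathcal{L}_r(t_k))$. Later chunks are thus emphasized only after earlier ones are resolved, which targets the kind of temporal residual imbalance shown in Figure 3.

Both sets of weights are treated as constants (`jax.lax.stop_gradient`) during back-propagation. The function names, `alpha=0.9`, and the small division guard are assumptions.

```python
import jax
import jax.numpy as jnp


def global_grad_norm(tree):
    # L2 norm over all leaves of a gradient pytree
    return jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(tree)))


def grad_norm_weights(params, loss_fns, weights, alpha=0.9):
    norms = jnp.stack([global_grad_norm(jax.grad(fn)(params)) for fn in loss_fns])
    # lambda_hat_i = (sum_j ||grad L_j||) / ||grad L_i||; 1e-12 guards against division by zero
    lam_hat = jnp.sum(norms) / (norms + 1e-12)
    new = alpha * weights + (1.0 - alpha) * lam_hat
    return jax.lax.stop_gradient(new)


def causal_weights(chunk_losses, eps):
    # w_i = exp(-eps * sum_{k<i} L_k): exclusive prefix sum, so w_0 = 1
    prefix = jnp.cumsum(chunk_losses) - chunk_losses
    return jax.lax.stop_gradient(jnp.exp(-eps * prefix))


def causal_residual_loss(chunk_losses, eps):
    # Causally weighted mean of per-chunk residual losses
    return jnp.mean(causal_weights(chunk_losses, eps) * chunk_losses)


p = jnp.array([1.0, 2.0])
loss_fns = [lambda q: jnp.sum(q**2), lambda q: 0.5 * jnp.sum(q**2)]
print(grad_norm_weights(p, loss_fns, jnp.ones(2), alpha=0.0))  # approx [1.5, 3.0]
print(causal_weights(jnp.array([1.0, 2.0, 3.0]), eps=1.0))  # [1, e^-1, e^-3]
```

In the toy example, the loss with the smaller gradient norm receives the larger weight, which is the balancing effect the scheme is designed to produce.
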
Theorems & Definitions (4)

  • Theorem B.1
  • Proof
  • Theorem B.2
  • Proof