Convergence and Sketching-Based Efficient Computation of Neural Tangent Kernel Weights in Physics-Based Loss

Max Hirsch, Federico Pichi

TL;DR

The paper studies the convergence of gradient descent in physics-informed neural networks when the loss weights are chosen adaptively from the neural tangent kernel (NTK), and addresses the resulting computational bottleneck with a randomized, sketch-based NTK estimator inspired by a predictor-corrector approach. It proves that, under appropriate assumptions, the average squared residuals $\tfrac{1}{T}\sum_{t=0}^{T-1} \|\mathcal{R}(\theta_t)\|^2$ and the average gradient norms converge to zero, even as the inner product $\Lambda(\theta_t)$ evolves. To enable frequent NTK-based weighting without prohibitive cost, it introduces a fast NTK estimation scheme using matrix sketching; a moving-average scheme further reduces the overhead to roughly two extra network evaluations per step, and the estimates are unbiased up to a discretization error in $\Delta t$. Numerical experiments on a wave-equation PINN and a nonlinear Q-tensor PINN corroborate the theory and demonstrate the practicality of the approach, including favorable comparisons to exact NTK-based weights and FEM baselines.

Abstract

In multi-objective optimization, multiple loss terms are weighted and added together to form a single objective. These weights are chosen to properly balance the competing losses according to some meta-goal. For example, in physics-informed neural networks (PINNs), these weights are often adaptively chosen to improve the network's generalization error. A popular choice of adaptive weights is based on the neural tangent kernel (NTK) of the PINN, which describes the evolution of the network in predictor space during training. The convergence of such an adaptive weighting algorithm is not clear a priori. Moreover, these NTK-based weights would be updated frequently during training, further increasing the computational burden of the learning process. In this paper, we prove that under appropriate conditions, gradient descent enhanced with adaptive NTK-based weights is convergent in a suitable sense. We then address the problem of computational efficiency by developing a randomized algorithm inspired by a predictor-corrector approach and matrix sketching, which produces unbiased estimates of the NTK up to an arbitrarily small discretization error. Finally, we provide numerical experiments to support our theoretical findings and to show the efficacy of our randomized algorithm. Code Availability: https://github.com/maxhirsch/Efficient-NTK
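To make the idea of a randomized, unbiased NTK estimate concrete, here is a minimal sketch in plain Python. It is not the paper's predictor-corrector algorithm: it uses a Hutchinson-style trace estimator, a different and simpler sketching technique, to estimate $\operatorname{tr}(JJ^\top)$ from vector-Jacobian products. The toy Jacobians `J_pde` and `J_bc`, the helper names, and the trace-ratio weighting rule are all assumptions made for illustration; the trace-ratio rule is one common choice in the NTK-weighting literature and may differ from the weights analyzed in the paper.

```python
import random

def ntk_trace_exact(J):
    """Exact trace of K = J J^T, i.e. the squared Frobenius norm of J."""
    return sum(v * v for row in J for v in row)

def ntk_trace_sketch(J, num_probes, rng):
    """Hutchinson-style estimate of tr(J J^T) from random sign probes.

    Each probe z has i.i.d. +/-1 entries, so E[z z^T] = I and
    E[||J^T z||^2] = tr(J J^T): the estimate is unbiased.
    """
    n, p = len(J), len(J[0])
    total = 0.0
    for _ in range(num_probes):
        z = [rng.choice((-1.0, 1.0)) for _ in range(n)]
        # J^T z is a single vector-Jacobian product; for a real network this
        # would be one backward pass instead of building the full Jacobian.
        jtz = [sum(J[i][k] * z[i] for i in range(n)) for k in range(p)]
        total += sum(v * v for v in jtz)
    return total / num_probes

rng = random.Random(0)
# Hypothetical toy Jacobians of two loss terms (PDE residual, boundary term).
J_pde = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(8)]
J_bc = [[rng.gauss(0.0, 0.1) for _ in range(5)] for _ in range(4)]

tr_pde = ntk_trace_sketch(J_pde, 2000, rng)
tr_bc = ntk_trace_sketch(J_bc, 2000, rng)

# Assumed trace-ratio rule: a term with a smaller NTK trace gets a larger weight.
total_trace = tr_pde + tr_bc
weights = {"pde": total_trace / tr_pde, "bc": total_trace / tr_bc}
```

In this toy setup the boundary Jacobian has much smaller entries, so its weight comes out larger than the PDE weight. The number of probes trades variance against cost, which is the same tension that motivates cheaper estimators when weights are refreshed at every training step.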

Paper Structure

This paper contains 18 sections, 8 theorems, 89 equations, 10 figures, 3 algorithms.

Key Result

Lemma 1

If $F \in C^1(\mathcal{P})$ with $\mathcal{P}\subseteq\mathbb{R}^p$ convex and $\nabla F$ is $L$-Lipschitz, then for all $x,y\in\mathcal{P}$,

$$F(y) \le F(x) + \langle \nabla F(x),\, y - x\rangle + \tfrac{L}{2}\,\|y - x\|^2.$$
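The textbook descent-lemma bound for a function with $L$-Lipschitz gradient can be checked numerically. The sketch below assumes a hypothetical diagonal quadratic `F` (not a function from the paper), for which the Lipschitz constant of the gradient is simply the largest coefficient, and verifies the bound on random pairs of points.

```python
import random

def F(x, a):
    """Toy objective F(x) = 0.5 * sum_i a_i * x_i^2 with a_i > 0."""
    return 0.5 * sum(ai * xi * xi for ai, xi in zip(a, x))

def grad_F(x, a):
    """Gradient of the toy objective: component i is a_i * x_i."""
    return [ai * xi for ai, xi in zip(a, x)]

def descent_gap(x, y, a):
    """Return [F(x) + <grad F(x), y - x> + L/2 * ||y - x||^2] - F(y).

    The descent lemma says this gap is nonnegative whenever grad F is
    L-Lipschitz; for this quadratic, L = max(a).
    """
    L = max(a)
    g = grad_F(x, a)
    d = [yi - xi for xi, yi in zip(x, y)]
    linear = sum(gi * di for gi, di in zip(g, d))
    quadratic = 0.5 * L * sum(di * di for di in d)
    return F(x, a) + linear + quadratic - F(y, a)

rng = random.Random(1)
a = [0.5, 1.0, 3.0]  # assumed curvature coefficients, so L = 3.0

def random_point():
    return [rng.uniform(-2.0, 2.0) for _ in a]

gaps = [descent_gap(random_point(), random_point(), a) for _ in range(1000)]
min_gap = min(gaps)  # should be nonnegative up to rounding error
```

For this quadratic the gap equals $\tfrac12\sum_i (L - a_i)(y_i - x_i)^2$, so it vanishes only along the direction of largest curvature, which is where the bound is tight.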

Figures (10)

  • Figure 1: Convergence behavior of the total training loss, composed of equal contributions from the mean PDE residual and the mean boundary terms.
  • Figure 2: Loss weights and NTK eigenvalues, which verify the assumptions labeled `assumption:lambda-bounds` and `assumption:K-bounds`.
  • Figure 3: Convergence rates of the time-averaged loss and the time-averaged norm of the gradient of the residuals.
  • Figure 4: Quadratically parameterized predictor, sampled data, and true output.
  • Figure 5: Approximation of NTK at a fixed time for differing numbers of samples.
  • ...and 5 more figures

Theorems & Definitions (24)

  • Remark 1
  • Remark 2
  • Lemma 1: Descent Lemma
  • proof
  • Theorem 1: Convergence of Residual Averages
  • proof
  • Remark 3
  • Theorem 2: Convergence of Gradient Averages
  • proof
  • Corollary 1: Convergence for NTK-Based Weights
  • ...and 14 more