
MSPT: Efficient Large-Scale Physical Modeling via Parallelized Multi-Scale Attention

Pedro M. P. Curvo, Jan-Willem van de Meent, Maksim Zhdanov

TL;DR

The paper tackles the scalability gap in neural PDE solvers for industrial-scale simulations by introducing MSPT, a multi-scale patch transformer. MSPT partitions irregular geometries into patches with a ball tree, applies self-attention within each patch, and adds global context through pooled supernodes, achieving near-linear scaling with the number of points. It delivers state-of-the-art accuracy on standard PDE benchmarks and large CFD datasets (ShapeNet-Car, AhmedML) while reducing memory and compute costs relative to prior transformer-based solvers. The work includes ablations on patch count and pooling, plus efficiency analyses, and highlights MSPT's potential for million-point inference on a single GPU and for real-time design optimization.
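To make the scaling argument concrete, the sketch below counts query-key score pairs per attention layer under a simplified model, which is an assumption and not the paper's stated analysis: each of N points attends to the points of its own patch plus all P·Q pooled supernodes, with P patches of near-equal size and Q supernodes per patch. The function names are hypothetical.

```python
import math


def full_attention_pairs(n: int) -> int:
    """Query-key pairs scored by dense self-attention over n points."""
    return n * n


def patch_attention_pairs(n: int, p: int, q: int = 1) -> int:
    """Pairs when each point attends to its own patch plus all p*q supernodes.

    Assumption: uses the largest patch size, ceil(n / p), as an upper bound
    when n does not split evenly into p patches.
    """
    patch_size = -(-n // p)  # ceiling division without floats
    return n * (patch_size + p * q)


def balanced_patch_count(n: int, q: int = 1) -> int:
    """Patch count that approximately minimizes n/p + p*q, i.e. p ~ sqrt(n / q)."""
    return max(1, round(math.sqrt(n / q)))


n = 500_000  # point count used in the Figure 1 comparison
print(full_attention_pairs(n) / patch_attention_pairs(n, p=256))  # about 226
print(balanced_patch_count(n))  # 707
```

In this model the cost per point is ⌈N/P⌉ + PQ: too few patches make local attention expensive, while too many inflate the shared supernode set, and the continuous optimum is P = √(N/Q). These are operation counts under the stated assumptions, not measured memory or latency.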

Abstract

A key scalability challenge in neural solvers for industrial-scale physics simulations is efficiently capturing both fine-grained local interactions and long-range global dependencies across millions of spatial elements. We introduce the Multi-Scale Patch Transformer (MSPT), an architecture that combines local point attention within patches with global attention to coarse patch-level representations. To partition the input domain into spatially coherent patches, we employ ball trees, which handle irregular geometries efficiently. This dual-scale design enables MSPT to scale to millions of points on a single GPU. We validate our method on standard PDE benchmarks (elasticity, plasticity, fluid dynamics, porous flow) and large-scale aerodynamic datasets (ShapeNet-Car, AhmedML), achieving state-of-the-art accuracy with substantially lower memory footprint and computational cost.
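The ball-tree partitioning step can be sketched as recursive median splits along the coordinate of largest spread, a common way to build balanced ball trees, which yields balanced, spatially coherent patches on irregular point clouds. This is a minimal NumPy sketch under that assumption; `ball_tree_patches` and `leaf_size` are hypothetical names, and the paper's exact construction may differ.

```python
import numpy as np


def ball_tree_patches(points: np.ndarray, leaf_size: int) -> list[np.ndarray]:
    """Split point indices into spatially coherent patches of at most leaf_size points.

    Hypothetical sketch: balanced binary partition via median splits along the
    widest coordinate, in the style of ball-tree construction.
    """
    patches = []
    stack = [np.arange(len(points))]
    while stack:
        idx = stack.pop()
        if len(idx) <= leaf_size:
            patches.append(idx)  # leaf node becomes one patch
            continue
        sub = points[idx]
        spread = sub.max(axis=0) - sub.min(axis=0)
        axis = int(np.argmax(spread))  # dimension with the largest extent
        order = idx[np.argsort(sub[:, axis], kind="stable")]
        half = len(order) // 2  # median split keeps the tree balanced
        stack.append(order[half:])
        stack.append(order[:half])
    return patches


pts = np.random.default_rng(0).random((1024, 3))  # irregular 3D point cloud
patches = ball_tree_patches(pts, leaf_size=64)
print(len(patches))  # 16 patches of 64 points each
```

When N is leaf_size times a power of two, every patch has exactly leaf_size points, so the patches can be stacked into a dense (P, K) index array for batched attention.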

Paper Structure

This paper contains 51 sections, 22 equations, 6 figures, and 8 tables.

Figures (6)

  • Figure 1: Parallelized Multi-Scale Attention mechanism. Each patch performs local self-attention, while pooled supernodes exchange information globally across patches in parallel. Peak memory (GB) and latency (ms) are compared on 500k points, using 256 slices for Transolver and 256 patches for MSPT.
  • Figure 2: MSPT-Block. Each block partitions the point set into patches and pools local information into a small set of supernodes (here, 1). Multi-head attention is applied within each patch, augmented by the shared supernodes (global context).
  • Figure 3: Examples of relative L2 error maps for the Pipe, Navier-Stokes and ShapeNet Car datasets. For ShapeNet Car we show surface-pressure errors. See the PDE benchmark appendix for more visualizations.
  • Figure 4: Study of pooling method and number of supernode tokens $Q$ per patch. Mean pooling consistently outperforms max pooling and a learned linear projection across $Q$ values. Increasing $Q$ generally lowers the validation loss.
  • Figure 5: Peak GPU memory usage (top) and wall-clock runtime per forward pass (bottom) as a function of the number of patches, across several input resolutions. Colors correspond to the input resolution (total number of points), as indicated by the color bar.
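The MSPT block described in Figures 2 and 4 can be sketched as attention in which each patch's points attend to their own patch plus a shared set of mean-pooled supernodes. This NumPy version is a hypothetical illustration, not the authors' implementation: it assumes the Q supernodes of a patch are formed by averaging Q contiguous chunks of that patch, uses a single head with no output projection, normalization, or MLP, and takes patches as a precomputed (P, K) index array. All function names are hypothetical.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mean_pool_supernodes(x_patches: np.ndarray, q: int) -> np.ndarray:
    """(P, K, d) -> (P, Q, d): average Q contiguous chunks of each patch (assumption)."""
    p, k, d = x_patches.shape
    assert k % q == 0, "sketch assumes the patch size is divisible by Q"
    return x_patches.reshape(p, q, k // q, d).mean(axis=2)


def multiscale_patch_attention(x, patches, wq, wk, wv, q=1):
    """Single-head sketch. x: (N, d); patches: (P, K) indices covering each point once."""
    xp = x[patches]                                             # (P, K, d) points grouped by patch
    sup = mean_pool_supernodes(xp, q).reshape(-1, x.shape[1])   # (P*Q, d) shared global tokens
    p = xp.shape[0]
    ctx = np.concatenate([xp, np.broadcast_to(sup, (p,) + sup.shape)], axis=1)  # (P, K+P*Q, d)
    queries, keys, values = xp @ wq, ctx @ wk, ctx @ wv
    scores = queries @ keys.transpose(0, 2, 1) / np.sqrt(keys.shape[-1])  # (P, K, K+P*Q)
    out_p = softmax(scores) @ values                            # (P, K, d_h)
    out = np.empty((x.shape[0], out_p.shape[-1]))
    out[patches.reshape(-1)] = out_p.reshape(-1, out_p.shape[-1])  # scatter to original order
    return out


rng = np.random.default_rng(0)
n, d, dh, p = 256, 8, 8, 4
x = rng.standard_normal((n, d))
patches = rng.permutation(n).reshape(p, n // p)  # stand-in for ball-tree patches
wq, wk, wv = (rng.standard_normal((d, dh)) / np.sqrt(d) for _ in range(3))
y = multiscale_patch_attention(x, patches, wq, wk, wv, q=2)
print(y.shape)  # (256, 8)
```

Because the supernodes are shared, changing any point alters the pooled token of its patch and therefore reaches every other patch within a single block, while each point still scores only K + P·Q keys rather than N.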