Apax: A Flexible and Performant Framework For The Development of Machine-Learned Interatomic Potentials

Moritz René Schäfer, Nico Segreto, Fabian Zills, Christian Holm, Johannes Kästner

TL;DR

Atomistic learned potentials in JAX (apax), a flexible and efficient open-source software package for training and inference of machine-learned interatomic potentials, is introduced. A Gaussian Moment Neural Network model implemented in apax achieves higher accuracy and up to 10 times faster inference than a performance-optimized Allegro model.

Abstract

We introduce Atomistic learned potentials in JAX (apax), a flexible and efficient open-source software package for training and inference of machine-learned interatomic potentials. Built on the JAX framework, apax supports GPU acceleration and implements flexible model abstractions for fast development. With features such as kernel-based data selection, well-calibrated uncertainty estimation, and enhanced sampling, it is tailored to active learning applications and ease of use. The features and design decisions of apax are discussed before some of its capabilities are demonstrated. First, a data set for the room-temperature ionic liquid EMIM⁺BF₄⁻ is created using active learning. We show that continuing to train models between iterations can reduce training times by up to 85% with only a minor loss in accuracy. Second, we show good scalability in a data-parallel training setting. We report that a Gaussian Moment Neural Network model, as implemented in apax, achieves higher accuracy and up to 10 times faster inference than a performance-optimized Allegro model. A recently published Li₃PO₄ data set, released together with accuracy and inference performance metrics, serves as the point of comparison. Moreover, the inference speeds of the available simulation engines are compared. Finally, to highlight the modularity of apax, an equivariant message-passing model is trained as a shallow ensemble and used to perform uncertainty-driven dynamics.
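The shallow-ensemble idea mentioned in the abstract can be illustrated with a minimal JAX sketch. It assumes a shared per-atom feature matrix and several linear readout heads; the names `shallow_ensemble_energy`, `head_weights`, and `head_bias` are hypothetical and are not the apax API. The ensemble mean serves as the prediction and the spread across heads as the uncertainty; how apax trains and calibrates these heads is not shown here.

```python
import jax
import jax.numpy as jnp


def shallow_ensemble_energy(features, head_weights, head_bias):
    """Energy mean and spread from a shallow (last-layer) ensemble.

    Hypothetical sketch, not the apax implementation.
    features:     (n_atoms, n_features) shared per-atom descriptors
    head_weights: (n_members, n_features) one linear readout per member
    head_bias:    (n_members,) per-member bias on the atomic energy
    """
    # Per-atom energies predicted by every member: (n_atoms, n_members)
    atomic_energies = features @ head_weights.T + head_bias
    # Total energy of the structure for each member: (n_members,)
    member_energies = jnp.sum(atomic_energies, axis=0)
    # Ensemble mean is the prediction, standard deviation the uncertainty
    mean = jnp.mean(member_energies)
    std = jnp.std(member_energies)
    return mean, std, member_energies


# Usage with random stand-in features for a 5-atom structure and 4 heads
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
features = jax.random.normal(k1, (5, 8))
head_weights = jax.random.normal(k2, (4, 8))
head_bias = jax.random.normal(k3, (4,))
mean, std, members = shallow_ensemble_energy(features, head_weights, head_bias)
```

Because all members share the expensive representation and differ only in the cheap readout, such an ensemble costs little more than a single model to evaluate, which is what makes it practical for uncertainty-driven dynamics.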

Paper Structure

This paper contains 22 sections, 10 equations, 11 figures, 1 table.

Figures (11)

  • Figure 1: Overview of the features and code structure of the package. Circles represent user-facing functionalities, rectangles internal feature groups, and diamond shapes data stored on disk.
  • Figure 2: Schematic representation of the model abstractions and their interaction with other functionalities in apax. Inputs and model outputs added with each transformation are shown in the gray boxes.
  • Figure 3: Selection heuristic and analysis of a combined and geometry optimization trajectory. a) sorted squared distances used in the MaxDist selection method. b) Energy for each configuration of the combined trajectory. c) First two principal components of the last-layer features for training data and data pool. The pool of data points is marked in blue, selected configurations are marked in red, and training configurations are marked in grey in each subplot.
  • Figure 4: Evaluation metrics of CL150$_{10}$, CL300$_{10}$, CL500$_{10}$ and R1000$_{10}$ for a test set of the MACE-MP0 production trajectory.
  • Figure 5: a) Hydrogen-boron radial distribution functions (RDFs) of trajectories produced with the final models CL150$_{10}$, CL300$_{10}$, CL500$_{10}$ and R1000$_{10}$ and with the MACE-MP0 foundation model, which serves as ground truth. For better visibility, the RDFs are shifted by an offset of $0.2$. b) RDF differences between the final models and the MACE-MP0 foundation model.
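The MaxDist selection referenced in Figure 3 can be sketched as greedy farthest-point selection in feature space. This is a minimal sketch under stated assumptions: the input arrays stand in for last-layer model features, plain squared Euclidean distance is used (the kernel-induced distance of a linear kernel), and `max_dist_selection` is a hypothetical name rather than the apax implementation.

```python
import jax.numpy as jnp


def max_dist_selection(train_feats, pool_feats, n_select):
    """Greedy farthest-point selection (MaxDist-style sketch, hypothetical).

    train_feats: (n_train, d) features of configurations already in training
    pool_feats:  (n_pool, d) features of candidate configurations
    Returns the selected pool indices and, for each pick, its squared
    distance to the nearest already-covered point when it was chosen.
    """
    # Squared distance from every pool point to its nearest training point
    diff = pool_feats[:, None, :] - train_feats[None, :, :]
    min_sq_dist = jnp.min(jnp.sum(diff**2, axis=-1), axis=1)
    selected, picked_dists = [], []
    for _ in range(n_select):
        # Pick the candidate farthest from everything covered so far
        idx = int(jnp.argmax(min_sq_dist))
        selected.append(idx)
        picked_dists.append(float(min_sq_dist[idx]))
        # The new pick now covers its neighbourhood: update nearest distances
        new_sq = jnp.sum((pool_feats - pool_feats[idx]) ** 2, axis=-1)
        min_sq_dist = jnp.minimum(min_sq_dist, new_sq)
    return selected, picked_dists


# Usage: one training point at the origin, three candidates in the pool
train = jnp.array([[0.0, 0.0]])
pool = jnp.array([[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]])
selected, dists = max_dist_selection(train, pool, n_select=3)
# selected → [1, 2, 0], dists → [9.0, 4.0, 1.0]
```

Because each pick maximizes the distance to everything already covered, the recorded squared distances are non-increasing, which is the kind of sorted-distance curve shown in panel a) of Figure 3.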