
From Molecules to Materials: Pre-training Large Generalizable Models for Atomic Property Prediction

Nima Shoghi, Adeesh Kolluru, John R. Kitchin, Zachary W. Ulissi, C. Lawrence Zitnick, Brandon M. Wood

TL;DR

JMP introduces a supervised multi-task pre-training framework that aggregates ~120M atomic structures from diverse chemical domains to learn transferable atomic representations. Using a single GemNet-OC backbone with per-dataset heads, temperature-based sampling, and structure-wise loss balancing, JMP achieves a 59% average improvement over training from scratch and matches or sets the state of the art on 34 of 40 downstream tasks across QM9, rMD17, MD22, SPICE, MatBench, and QMOF. The findings demonstrate the value of cross-domain pre-training for atomic property prediction and show that larger models benefit low-data tasks, albeit at a substantial upfront compute cost. The work also provides detailed ablation and cost analyses, highlighting the practical trade-offs and outlining directions for scaling and for exploring other backbones.
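The TL;DR mentions temperature-based sampling across the pre-training datasets. A minimal sketch of the common form of this technique, in which dataset k is drawn with probability proportional to its share of the data raised to the power 1/T: T = 1 recovers size-proportional sampling, and larger T up-weights the smaller datasets. The function names, the dataset sizes in the usage lines, and the choice T = 2 are illustrative assumptions, not details taken from the JMP code.

```python
import random


def temperature_sampling_probs(sizes: dict[str, int], temperature: float) -> dict[str, float]:
    """Per-dataset sampling probabilities p_k proportional to (n_k / N) ** (1 / T)."""
    total = sum(sizes.values())
    # T = 1 keeps natural (size-proportional) sampling; T > 1 flattens it,
    # so small datasets are visited more often than their raw share.
    weights = {name: (n / total) ** (1.0 / temperature) for name, n in sizes.items()}
    norm = sum(weights.values())
    return {name: w / norm for name, w in weights.items()}


def sample_dataset(probs: dict[str, float], rng: random.Random) -> str:
    """Choose the dataset from which the next pre-training example is drawn."""
    names = list(probs)
    return rng.choices(names, weights=[probs[n] for n in names], k=1)[0]


if __name__ == "__main__":
    # Hypothetical dataset sizes (millions of structures), not JMP's real counts.
    sizes = {"large": 100, "medium": 10, "small": 1}
    probs = temperature_sampling_probs(sizes, temperature=2.0)  # T = 2 is an assumed value
    print({name: round(p, 3) for name, p in probs.items()})
    # -> {'large': 0.706, 'medium': 0.223, 'small': 0.071}
    rng = random.Random(0)
    print([sample_dataset(probs, rng) for _ in range(5)])
```

With T = 1 the hypothetical "small" dataset would be drawn under 1% of the time; at T = 2 it rises to about 7%, which is the kind of rebalancing temperature sampling is meant to provide when one source dwarfs the others.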

Abstract

Foundation models have been transformational in machine learning fields such as natural language processing and computer vision. Similar success in atomic property prediction has been limited due to the challenges of training effective models across multiple chemical domains. To address this, we introduce Joint Multi-domain Pre-training (JMP), a supervised pre-training strategy that simultaneously trains on multiple datasets from different chemical domains, treating each dataset as a unique pre-training task within a multi-task framework. Our combined training dataset consists of $\sim$120M systems from OC20, OC22, ANI-1x, and Transition-1x. We evaluate performance and generalization by fine-tuning over a diverse set of downstream tasks and datasets including QM9, rMD17, MatBench, QMOF, SPICE, and MD22. JMP demonstrates an average improvement of 59% over training from scratch, and matches or sets state-of-the-art on 34 out of 40 tasks. Our work highlights the potential of pre-training strategies that utilize diverse data to advance property prediction across chemical domains, especially for low-data tasks. Please visit https://nima.sh/jmp for further information.
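The abstract treats each pre-training dataset as its own task, and the TL;DR names structure-wise loss balancing. A minimal sketch of one way such a multi-task objective could be reduced, assuming an absolute energy error plus a mean per-atom L2 force error, with hypothetical per-dataset weights `coeffs`. The point is the order of averaging: force errors are averaged within each structure first and then across structures, so a structure with many atoms does not outweigh a small one. The class and function names are illustrative, not taken from the JMP code, and predictions are assumed to already come from the head belonging to each structure's dataset.

```python
import math
from dataclasses import dataclass

Vec3 = tuple[float, float, float]


@dataclass
class Structure:
    dataset: str             # source dataset; would select the prediction head in a multi-head model
    energy_true: float
    energy_pred: float       # assumed to come from the head for `dataset`
    forces_true: list[Vec3]  # one force vector per atom
    forces_pred: list[Vec3]


def structure_force_loss(s: Structure) -> float:
    """Mean per-atom L2 force error within a single structure."""
    errs = [math.dist(f, g) for f, g in zip(s.forces_true, s.forces_pred)]
    return sum(errs) / len(errs)


def multitask_loss(batch: list[Structure], coeffs: dict[str, tuple[float, float]]) -> float:
    """Structure-wise reduction: each structure counts once, whatever its atom count.

    coeffs[dataset] = (energy_weight, force_weight); these per-dataset weights are hypothetical.
    """
    total = 0.0
    for s in batch:
        w_e, w_f = coeffs[s.dataset]
        total += w_e * abs(s.energy_true - s.energy_pred) + w_f * structure_force_loss(s)
    return total / len(batch)


def atomwise_force_loss(batch: list[Structure]) -> float:
    """For contrast: pooling all atoms lets structures with many atoms dominate."""
    errs = [math.dist(f, g) for s in batch for f, g in zip(s.forces_true, s.forces_pred)]
    return sum(errs) / len(errs)


if __name__ == "__main__":
    # A 1-atom structure with a unit force error and a 3-atom structure with none.
    small = Structure("dataset_a", 0.0, 0.0, [(1.0, 0.0, 0.0)], [(0.0, 0.0, 0.0)])
    big = Structure("dataset_b", 0.0, 0.0, [(0.0, 0.0, 0.0)] * 3, [(0.0, 0.0, 0.0)] * 3)
    coeffs = {"dataset_a": (1.0, 1.0), "dataset_b": (1.0, 1.0)}
    print(multitask_loss([small, big], coeffs))  # -> 0.5
    print(atomwise_force_loss([small, big]))     # -> 0.25
```

In the usage lines, the structure-wise reduction gives the one-atom structure's error half the weight of the batch loss (0.5), while atom-wise pooling dilutes it to 0.25 because the three-atom structure contributes three terms.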

Paper Structure (31 sections, 3 equations, 6 figures, 19 tables)

Figures (6)

  • Figure 1: An overview of the Joint Multi-domain Pre-training (JMP) method. Left: JMP's pre-training setup, where a single model is simultaneously trained on a set of diverse pre-training datasets using multi-task learning. Center: JMP's fine-tuning process, where the pre-trained JMP backbone is equipped with new prediction heads and trained on downstream tasks. Right: t-SNE visualizations of JMP's node-level ($\tilde{h}$) embeddings for randomly selected structures from all datasets.
  • Figure 2: Relative performance improvement (%) across all tasks of every fine-tuning dataset for (a) Scratch Large (GN-OC-L) over Scratch Small (GN-OC-S), (b) Fine-tuned Large (JMP-L) over Fine-tuned Small (JMP-S), and (c) Fine-tuned Large (JMP-L) over Scratch Large (GN-OC-L). GN-OC scales poorly to the larger model, a clear sign of overfitting, whereas JMP reverses this trend and exhibits much better scaling. JMP also consistently outperforms GN-OC across all domains, datasets, and targets. The shaded rectangles indicate the average relative performance across all tasks of each dataset. The exact percentages are given in the appendix.
  • Figure 3: Relative improvement, over training from scratch, of different pre-training methods on QM9's $\epsilon_\text{LUMO}$ and $\epsilon_\text{HOMO}$.
  • Figure 4: The number of GPU hours, averaged per dataset, required to train GN-OC-L to convergence and to fine-tune JMP-L to match GN-OC-L's performance. Overall, fine-tuning JMP-L matched GN-OC-L's performance in $\frac{1}{12}$ of the time.
  • Figure 5: t-SNE visualizations of the node-level ($\tilde{h}$) and edge-level ($\tilde{m}$) JMP-L embeddings for randomly selected structures from all pre-training and fine-tuning development datasets. Each point represents a structure, and the color indicates the dataset from which the structure was sampled.
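Figures 2 and 3 report relative improvements in percent, the abstract's 59% figure is reported as an average improvement over training from scratch, and Figure 4 summarizes cost as a fraction of GPU hours. A minimal sketch of these quantities, assuming the usual definition for lower-is-better errors such as MAE (the percentage reduction of the model's error relative to the baseline's) and a plain mean over tasks. The exact formula and averaging the authors used are assumptions here, the function names are hypothetical, and all numbers in the usage lines are made up for illustration.

```python
def relative_improvement(err_baseline: float, err_model: float) -> float:
    """Percent error reduction of a model over a baseline (lower error is better)."""
    return 100.0 * (err_baseline - err_model) / err_baseline


def average_relative_improvement(pairs: list[tuple[float, float]]) -> float:
    """Mean of per-task relative improvements; each pair is (baseline_error, model_error)."""
    values = [relative_improvement(b, m) for b, m in pairs]
    return sum(values) / len(values)


def time_fraction(gpu_hours_model: float, gpu_hours_baseline: float) -> float:
    """Fraction of the baseline's GPU hours the model needed to reach the same error."""
    return gpu_hours_model / gpu_hours_baseline


if __name__ == "__main__":
    # Hypothetical (scratch MAE, fine-tuned MAE) pairs for two tasks.
    print(average_relative_improvement([(2.0, 1.0), (4.0, 1.0)]))  # -> 62.5
    # Hypothetical GPU-hour totals, not values from the paper; 10 / 120 is 1/12.
    print(time_fraction(10.0, 120.0))
```

A negative value from `relative_improvement` means the model is worse than its baseline, which is how a panel like Figure 2(a) can show the larger scratch model underperforming the smaller one.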