
Understanding and Mitigating Distribution Shifts For Machine Learning Force Fields

Tobias Kreiman, Aditi S. Krishnapriyan

TL;DR

This paper investigates why state-of-the-art MLFFs fail to generalize under distribution shifts across chemical space, even for large models trained on extensive data. It diagnoses shifts in atomic features, force norms, and graph connectivity, and introduces two test-time refinement strategies: test-time radius refinement (RR) to align test graph spectra with training graphs, and test-time training (TTT) with cheap priors to regularize representations without needing reference labels. Across SPICE/SPICEv2 and extreme-molecule benchmarks, RR and TTT reduce errors, improve MD stability, and substantially lower the data required for fine-tuning, suggesting MLFFs can generalize to more diverse chemistries when trained with distribution-shift-aware strategies. The work provides practical benchmarks and code to evaluate and advance the generalization capabilities of the next generation of MLFFs.
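The radius-refinement idea can be sketched in Python with NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation. It assumes:

- the symmetric normalized Laplacian $I - D^{-1/2} A D^{-1/2}$;
- a normalized histogram of its eigenvalues as each graph's spectral summary;
- an L1 distance between histograms as the spectral mismatch;
- a small grid of candidate cutoffs.

The paper's exact spectral distance and search procedure may differ, and all function names here are hypothetical.

```python
import numpy as np


def radius_graph_adjacency(positions: np.ndarray, cutoff: float) -> np.ndarray:
    """Adjacency matrix connecting atoms closer than `cutoff` (no self-loops)."""
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    adj = (dist < cutoff).astype(float)
    np.fill_diagonal(adj, 0.0)
    return adj


def normalized_laplacian_spectrum(adj: np.ndarray) -> np.ndarray:
    """Eigenvalues of I - D^{-1/2} A D^{-1/2}, clipped to the valid range [0, 2].
    Isolated atoms (degree 0) contribute an eigenvalue of 1 under this convention."""
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    nz = deg > 0
    inv_sqrt[nz] = 1.0 / np.sqrt(deg[nz])
    lap = np.eye(len(adj)) - inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    return np.clip(np.linalg.eigvalsh(lap), 0.0, 2.0)


def spectral_histogram(eigs: np.ndarray, bins: int = 21) -> np.ndarray:
    """Normalized eigenvalue histogram on [0, 2]; an odd bin count keeps 1.0 off a bin edge."""
    hist, _ = np.histogram(eigs, bins=bins, range=(0.0, 2.0))
    return hist / max(hist.sum(), 1)


def refine_radius(positions, train_hist, candidate_cutoffs, bins: int = 21):
    """Return the cutoff whose test-graph spectrum is closest (L1) to the training histogram."""
    best_cutoff, best_dist = None, np.inf
    for cutoff in candidate_cutoffs:
        eigs = normalized_laplacian_spectrum(radius_graph_adjacency(positions, cutoff))
        d = np.abs(spectral_histogram(eigs, bins) - train_hist).sum()
        if d < best_dist:
            best_cutoff, best_dist = cutoff, d
    return best_cutoff, best_dist


# Toy usage: the "training" graph is a unit square at cutoff 1.2 (a 4-cycle).
train_pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
train_hist = spectral_histogram(normalized_laplacian_spectrum(radius_graph_adjacency(train_pos, 1.2)))
# A stretched test system: at the training cutoff all of its atoms are disconnected.
cutoff, dist = refine_radius(2.0 * train_pos, train_hist, [1.2, 1.8, 2.4, 3.0])
```

In the toy usage, the square stretched by a factor of two has no edges at the training cutoff. Raising the cutoff to 2.4 recovers the training 4-cycle and its spectrum exactly. At 3.0 the graph becomes fully connected and the spectrum drifts away again.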

Abstract

Machine Learning Force Fields (MLFFs) are a promising alternative to expensive ab initio quantum mechanical molecular simulations. Given the diversity of chemical spaces that are of interest and the cost of generating new data, it is important to understand how MLFFs generalize beyond their training distributions. To characterize and better understand distribution shifts in MLFFs, we conduct diagnostic experiments on chemical datasets, revealing common shifts that pose significant challenges, even for large foundation models trained on extensive data. Based on these observations, we hypothesize that current supervised training methods inadequately regularize MLFFs, resulting in overfitting and poor representations of out-of-distribution systems. We then propose two new methods as initial steps for mitigating distribution shifts for MLFFs. Our methods focus on test-time refinement strategies that incur minimal computational cost and do not use expensive ab initio reference labels. The first strategy, based on spectral graph theory, modifies the edges of test graphs to align with graph structures seen during training. Our second strategy improves representations for out-of-distribution systems at test time by taking gradient steps using an auxiliary objective, such as a cheap physical prior. Our test-time refinement strategies significantly reduce errors on out-of-distribution systems, suggesting that MLFFs have the capacity to model diverse chemical spaces but are not being effectively trained to do so. Our experiments establish clear benchmarks for evaluating the generalization capabilities of the next generation of MLFFs. Our code is available at https://tkreiman.github.io/projects/mlff_distribution_shifts/.

Paper Structure

This paper contains 67 sections, 4 theorems, 20 equations, 20 figures, 15 tables.

Key Result

Theorem 4.1

If the reference energies diverge to $\infty$ as pairwise interatomic distances go to $0$, then there exist test-time training inputs for which a gradient step on the prior loss, using the Lennard-Jones potential as the prior, reduces the main-task loss on those inputs.
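For reference, the Lennard-Jones pair potential named in this theorem is $V(r) = 4\epsilon[(\sigma/r)^{12} - (\sigma/r)^{6}]$. Like the reference energies assumed in the theorem, it diverges as $r \to 0$.

The sketch below computes the total energy and the forces $F_i = -\partial E / \partial x_i$ over all atom pairs. The reduced units ($\epsilon = \sigma = 1$) and the simple O(N²) double loop are illustrative assumptions. The prior parameters the paper actually uses are not stated in this summary.

```python
import numpy as np


def lennard_jones(positions: np.ndarray, epsilon: float = 1.0, sigma: float = 1.0):
    """Total Lennard-Jones energy and per-atom forces (F = -dE/dx) summed over all pairs."""
    positions = np.asarray(positions, dtype=float)
    n = len(positions)
    energy = 0.0
    forces = np.zeros_like(positions)
    for i in range(n):
        for j in range(i + 1, n):
            rij = positions[i] - positions[j]
            r = np.linalg.norm(rij)
            sr6 = (sigma / r) ** 6
            energy += 4.0 * epsilon * (sr6**2 - sr6)
            # dV/dr = 4 eps (-12 (sigma/r)^12 + 6 (sigma/r)^6) / r
            dv_dr = 4.0 * epsilon * (-12.0 * sr6**2 + 6.0 * sr6) / r
            f = -dv_dr * rij / r  # force on atom i from atom j
            forces[i] += f
            forces[j] -= f        # Newton's third law
    return energy, forces


# At the pair minimum r = 2^(1/6) sigma the energy is -epsilon and the force vanishes;
# at short range the energy grows without bound, as the theorem's hypothesis requires.
e_min, f_min = lennard_jones(np.array([[0.0, 0.0, 0.0], [2 ** (1 / 6), 0.0, 0.0]]))
e_close, _ = lennard_jones(np.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]))
```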

Figures (20)

  • Figure 1: Distribution Shifts for MLFFs. We visualize distribution shifts based on changes in features, labels, and graph structure. Typical training samples from SPICE Eastman2023spice and new systems from SPICEv2 eastman2024spice2 are displayed. An atomic feature shift is illustrated by comparing a three-atom molecule with a larger molecular system containing 91 atoms (left). A force norm shift is shown by the close proximity of an $H_2$ molecule (circled in pink), leading to high force norms (middle). A connectivity shift is shown by the tetrahedral geometry in $P_4S_6$, which differs from the typical planar geometry seen during training (right).
  • Figure 2: Distribution Shifts for Large Models. We study distribution shifts on four of the largest open-source MLFFs designed for broad chemical spaces. (a) We evaluate MACE-MP on the MPTrj train set. (b) We evaluate MACE-OFF on 10k new molecules from SPICEv2. (c) We evaluate EquiformerV2 on the OC20 out-of-distribution validation set. (d) We evaluate JMP on the ANI-1x test set. A molecule is considered out-of-distribution if it is more than 1 standard deviation away from the training mean in force norm, system size, or connectivity (measured with the spectral distance defined in the paper's section on distribution-shift criteria). Despite their scale, these large foundation models have $2$-$10\times$ larger force mean absolute errors (MAE) when encountering distribution shifts.
  • Figure 3: Test-Time Radius Refinement. MLFFs tend to overfit to the well-connected graphs seen during training, which can be identified by the clustering of Laplacian eigenvalues around 1. To mitigate connectivity distribution shifts at test time, we find the optimal radius cutoff, which aligns the Laplacian eigenvalues of test graphs with those of the training distribution.
  • Figure 4: Test-Time Training Mitigates Distribution Shifts and Smooths Predicted Potential Energy Surfaces. We hypothesize that, due to overfitting, the predicted potential energy surfaces are jagged for out-of-distribution systems. Our proposed test-time training method (TTT, a) regularizes MLFFs by incorporating inductive biases into the model using a cheap prior. Test-time training first learns useful representations from the prior using either joint training or a pre-train, freeze, and fine-tune approach. TTT then updates the representations at test time using the prior to improve performance on out-of-distribution samples. We plot the predicted potential energy surface from a GemNet-dT model along the two principal components of the Hessian for salicylic acid, a molecule not seen during training, before and after test-time training (b). TTT effectively smooths the potential energy landscape and reduces errors.
  • Figure 5: Test-Time Training Decreases the Amount of Fine-Tuning Data Needed to Match In-Distribution Performance. We fine-tune GemNet-T models, trained on SPICE, on new molecules from the SPICEv2 dataset. Applying TTT on the new data before fine-tuning decreases the amount of training data needed to match the in-distribution performance by $10\times$. Applying TTT before fine-tuning also decreases the final error by 25% when training on all the data.
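The out-of-distribution criterion from the Figure 2 caption can be written as a simple mask. In this hypothetical sketch each statistic is one scalar per system, for example:

- a mean per-atom force norm;
- the number of atoms;
- a spectral distance to the training graphs.

How the paper aggregates forces and computes the per-molecule spectral distance is an assumption here. A system is flagged if it lies more than `k` training standard deviations from the training mean in any statistic; the caption uses k = 1.

```python
import numpy as np


def ood_mask(train_stats: dict, test_stats: dict, k: float = 1.0) -> np.ndarray:
    """Flag test systems that lie more than k training standard deviations from the
    training mean in ANY of the provided per-system statistics."""
    mask = None
    for name, train_vals in train_stats.items():
        train_vals = np.asarray(train_vals, dtype=float)
        mu, sd = train_vals.mean(), train_vals.std()
        flagged = np.abs(np.asarray(test_stats[name], dtype=float) - mu) > k * sd
        mask = flagged if mask is None else (mask | flagged)
    return mask


# Hypothetical statistics: the second system has an unusual force norm,
# the third an unusual size.
train = {"force_norm": [1.0, 2.0, 3.0], "n_atoms": [10, 12, 14]}
test = {"force_norm": [2.0, 5.0, 2.0], "n_atoms": [12, 12, 30]}
mask = ood_mask(train, test)  # -> [False, True, True]
```

Errors can then be reported separately for `mask` and `~mask` to compare in-distribution and out-of-distribution performance.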

Theorems & Definitions (5)

  • Theorem 4.1
  • Theorem B.1: TTT with a Lennard-Jones Prior Improves Performance on Quantum Mechanical Predictions
  • proof
  • Theorem B.2: Extrapolation to new graphs bechlerspeicher2024gnnregular
  • Theorem B.3: Extrapolation within regular graph distributions bechlerspeicher2024gnnregular