Self-supervised learning for crystal property prediction via denoising

Alexander New, Nam Q. Le, Michael J. Pekala, Christopher D. Stiles

TL;DR

CDSSL tackles the scarcity of labeled crystal-property data by pretraining graph-based crystal models on a denoising pretext task: atomic positions are perturbed, and the model predicts the edge embeddings of the original structure, which encourages a generalizable representation of structure space. The method combines a multigraph crystal representation with a MEGNet backbone and Set2Set aggregation so that the pretrained model can be transferred to diverse property-prediction tasks, yielding improved accuracy over non-SSL baselines. Empirical results show consistent gains across material classes, properties, and data regimes, including low-data settings. Linear probing and the density-volume structure of the embedding indicate that the learned representation captures meaningful material variation. This approach makes it possible to use large unlabeled structural databases to improve targeted crystal-property predictions, with potential for richer physical insights via the learned potential-energy-informed space.

Abstract

Accurate prediction of the properties of crystalline materials is crucial for targeted discovery, and this prediction is increasingly done with data-driven models. However, for many properties of interest, the number of materials for which a specific property has been determined is much smaller than the number of known materials. To overcome this disparity, we propose a novel self-supervised learning (SSL) strategy for material property prediction. Our approach, crystal denoising self-supervised learning (CDSSL), pretrains predictive models (e.g., graph networks) with a pretext task based on recovering valid material structures when given perturbed versions of these structures. We demonstrate that CDSSL models outperform models trained without SSL across material types, properties, and dataset sizes.
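
The denoising pretext task described above can be illustrated with a minimal sketch. It is not the authors' implementation and rests on several assumptions: plain interatomic distances stand in for the edge embeddings, periodic images are ignored, the loss is a mean squared error (the paper's objective may differ), and the function names `pairwise_distances`, `perturb`, and `denoising_loss` are hypothetical.

```python
import math
import random


def pairwise_distances(positions):
    """Return {(i, j): distance} for every atom pair i < j.

    Assumption: raw distances stand in for edge embeddings, and
    periodic images of the crystal are ignored.
    """
    edges = {}
    for i in range(len(positions)):
        for j in range(i + 1, len(positions)):
            edges[(i, j)] = math.dist(positions[i], positions[j])
    return edges


def perturb(positions, sigma, rng):
    """Add independent Gaussian noise (std. dev. sigma) to every coordinate."""
    return [tuple(c + rng.gauss(0.0, sigma) for c in p) for p in positions]


def denoising_loss(predicted_edges, clean_edges):
    """Mean squared error against the clean edge features (assumed loss form)."""
    total = sum((predicted_edges[k] - clean_edges[k]) ** 2 for k in clean_edges)
    return total / len(clean_edges)


rng = random.Random(0)
clean_positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (0.0, 1.5, 0.0)]
clean_edges = pairwise_distances(clean_positions)        # targets from G
noisy_positions = perturb(clean_positions, sigma=0.1, rng=rng)  # G-tilde

# Stand-in for the model: read edge features straight off the noisy structure.
# A trained model would be expected to do better than this baseline.
baseline_prediction = pairwise_distances(noisy_positions)
baseline_loss = denoising_loss(baseline_prediction, clean_edges)

# A model that recovers the clean edge features exactly incurs zero loss.
perfect_loss = denoising_loss(clean_edges, clean_edges)
```

In this toy setting, pretraining would amount to adjusting a model's parameters so that its loss on perturbed inputs falls below `baseline_loss` and approaches `perfect_loss`.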

Paper Structure (11 sections, 6 equations, 5 figures, 6 tables)

This paper contains 11 sections, 6 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: We summarize the CDSSL pretext task. The node positions of a structure $G$ are perturbed with Gaussian noise to create a structure $\tilde{G}$. The model $h_\theta$ takes the perturbed structure $\tilde{G}$ as input and seeks to output the edge embeddings of the original structure $G$.
  • Figure 2: We summarize the application of our framework to a property prediction task. In the top row, a MEGNet module is trained to denoise crystal structures (Figure 1 and eq. \ref{eq:ssl_task}) with predicted edge embeddings $\hat{e}_{v,v',k}$. Once the module has been trained, we can finetune it on property prediction. This entails passing the node-level outputs $\hat{x}_v$, edge-level outputs $\hat{u}_{v,v',k}$, and graph-level output $\hat{s}$ through Set2Set (Vinyals et al., 2016) modules to output the predicted property $\hat{y}$.
  • Figure 3: We show metrics from pretraining a MEGNet model with the CDSSL pretraining objective. Training yields slow but consistent decreases in both the training loss (eq. \ref{eq:ssl_task}) and the metric computed on the $h_\theta(\tilde{G}) - \bar{E}$ quantity (for both the training and validation sets). The curves for the training and validation sets overlap, indicating that the model is not overfitting. The pretraining task remains unstable during training, as evidenced by the jump in the metrics and in the gradient norm of the loss at the end of training.
  • Figure 4: We demonstrate the effects of using CDSSL pretraining vs. training without SSL across a variety of datasets and dataset sizes. Each bar reports error on the evaluation set, averaged over $3$ data splits and network initializations, and error bars show the standard error of that mean. The model finetuned after CDSSL pretraining has a lower error than the model trained without SSL in $37$ out of $49$ (dataset, dataset size) configurations.
  • Figure 5: We use UMAP (McInnes et al., 2020) to learn a reduced representation of the matbench_mp_e_form dataset used for pretraining with CDSSL (eq. \ref{eq:representation}). We shade points by their corresponding structure's density. Within the reduced representation, structures with similar densities are near each other. This suggests that the representation space learned via CDSSL has captured general notions of material properties. Error metrics are reported in the unit of each dataset's property.
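
The aggregation described in the Figure 4 caption (a mean over 3 runs, a standard error on that mean, and a count of configurations where the pretrained model wins) can be sketched as follows. The error values are made-up placeholders, not results from the paper. The dataset names, the helper names `mean_and_standard_error` and `count_wins`, and the use of the sample standard deviation are all assumptions.

```python
import math
import statistics


def mean_and_standard_error(errors):
    """Return (mean, standard error of the mean) for a list of run errors.

    Assumption: the standard error uses the sample standard deviation
    (n - 1 in the denominator); the paper does not state which is used.
    """
    mean = statistics.fmean(errors)
    sem = statistics.stdev(errors) / math.sqrt(len(errors))
    return mean, sem


def count_wins(pretrained_runs, baseline_runs):
    """Count configurations where the pretrained mean error is lower."""
    wins = 0
    for config, errors in pretrained_runs.items():
        if statistics.fmean(errors) < statistics.fmean(baseline_runs[config]):
            wins += 1
    return wins


# Placeholder errors keyed by (dataset, dataset size); three runs each.
# These numbers are illustrative only and do not come from the paper.
pretrained_runs = {
    ("dataset_a", 100): [0.30, 0.32, 0.34],
    ("dataset_a", 1000): [0.20, 0.21, 0.22],
    ("dataset_b", 100): [0.55, 0.50, 0.60],
}
baseline_runs = {
    ("dataset_a", 100): [0.36, 0.35, 0.37],
    ("dataset_a", 1000): [0.22, 0.23, 0.24],
    ("dataset_b", 100): [0.50, 0.52, 0.48],
}

mean_a, sem_a = mean_and_standard_error(pretrained_runs[("dataset_a", 100)])
wins = count_wins(pretrained_runs, baseline_runs)
total = len(pretrained_runs)
```

With only 3 runs per configuration, the standard error is itself a noisy estimate, so small gaps between bars should be read with caution.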