Learning Deep Hybrid Models with Sharpness-Aware Minimization

Naoya Takeishi

TL;DR

Hybrid modeling that couples scientific equations with neural networks risks letting the neural part absorb all of the predictive power. The authors adapt Sharpness-Aware Minimization (SAM) to hybrid models, perturbing only the neural component so that training is pushed toward flat loss minima and the scientific model is actually used. Across six synthetic and real datasets, SAM variants improve identification of the scientific parameters while maintaining competitive predictive performance, without requiring architecture-specific regularizers. The approach offers a broadly applicable, regulator-friendly path to robust, interpretable hybrid modeling, though theoretical identifiability guarantees remain an open question.
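One way to write the objective described above, as a sketch rather than the paper's exact formulation: assume $\theta$ denotes the scientific-model parameters and $\phi$ the neural-network parameters (the symbols used in Figure 1), $\mathcal{L}$ is the training loss, $\rho$ is the perturbation radius, and $\eta$ is the step size. Following the usual SAM first-order approximation, but restricting the worst-case perturbation to $\phi$ only:

$$
\min_{\theta,\phi}\; \max_{\|\epsilon\|_2 \le \rho} \mathcal{L}(\theta, \phi + \epsilon),
\qquad
\hat{\epsilon} = \rho \, \frac{\nabla_{\phi} \mathcal{L}(\theta, \phi)}{\|\nabla_{\phi} \mathcal{L}(\theta, \phi)\|_2},
$$

$$
(\theta, \phi) \leftarrow (\theta, \phi) - \eta \, \nabla_{(\theta, \phi)} \mathcal{L}(\theta, \phi + \hat{\epsilon}).
$$

Here $\hat{\epsilon}$ approximates the inner maximizer, and $\theta$ itself is never perturbed, so only the neural component is penalized for sitting in a sharp region of the loss.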

Abstract

Hybrid modeling, the combination of machine learning models and scientific mathematical models, enables flexible and robust data-driven prediction with partial interpretability. However, the scientific model may effectively be ignored in prediction because of the flexibility of the machine learning model, which defeats the purpose of hybrid modeling. Some regularization is typically applied during hybrid model learning to avoid such a failure case, but the formulation of the regularizer strongly depends on model architectures and domain knowledge. In this paper, we propose to focus on the flatness of loss minima in learning hybrid models, aiming to make the model as simple as possible. We employ the idea of sharpness-aware minimization (SAM) and adapt it to the hybrid modeling setting. Numerical experiments show that the SAM-based method works well across different choices of models and datasets.
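A SAM step that perturbs only the neural component can be sketched in NumPy on a toy problem. Everything here is hypothetical and chosen so the gradients are analytic: the "scientific model" is a linear law theta * x, the "neural network" is a fixed random-feature layer whose output weights phi are trained, and the loss is mean squared error. The helper names (`hybrid_predict`, `loss_and_grads`, `sam_step_neural_only`) are illustrative, not the paper's code.

```python
import numpy as np

def hybrid_predict(theta, phi, x, feats):
    # Hypothetical scientific part (theta * x) plus a random-feature "neural" part.
    return theta * x + feats @ phi

def loss_and_grads(theta, phi, x, feats, y):
    # Mean squared error and its analytic gradients w.r.t. theta and phi.
    r = hybrid_predict(theta, phi, x, feats) - y
    n = x.shape[0]
    loss = float(np.mean(r ** 2))
    g_theta = 2.0 * float(np.mean(r * x))
    g_phi = 2.0 * feats.T @ r / n
    return loss, g_theta, g_phi

def sam_step_neural_only(theta, phi, x, feats, y, lr=0.05, rho=0.05):
    # 1) Gradient w.r.t. the neural parameters at the current point.
    _, _, g_phi = loss_and_grads(theta, phi, x, feats, y)
    # 2) First-order worst-case perturbation of phi only; theta is not perturbed.
    eps = rho * g_phi / (np.linalg.norm(g_phi) + 1e-12)
    # 3) Gradients for both parts, evaluated at the perturbed point (theta, phi + eps).
    _, g_theta_adv, g_phi_adv = loss_and_grads(theta, phi + eps, x, feats, y)
    # 4) Descend from the unperturbed point using those gradients.
    return theta - lr * g_theta_adv, phi - lr * g_phi_adv

# Toy data: a linear law plus a small nonlinear residual and noise (all made up).
rng = np.random.default_rng(0)
x = np.linspace(-2.0, 2.0, 200)
y = 1.5 * x + 0.3 * np.sin(3.0 * x) + 0.05 * rng.standard_normal(x.shape)
v, b = rng.standard_normal(16), rng.standard_normal(16)
feats = np.tanh(np.outer(x, v) + b)  # fixed random features, shape (200, 16)

theta, phi = 0.0, np.zeros(16)
initial_loss = loss_and_grads(theta, phi, x, feats, y)[0]
for _ in range(500):
    theta, phi = sam_step_neural_only(theta, phi, x, feats, y)
final_loss = loss_and_grads(theta, phi, x, feats, y)[0]
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, theta = {theta:.3f}")
```

Setting `rho=0.0` turns the step into plain gradient descent (ERM), so the same loop can be rerun with and without the restricted perturbation to compare how much of the fit ends up in theta.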

Paper Structure

This paper contains 51 sections, 39 equations, 6 figures, 1 table, and 1 algorithm.

Figures (6)

  • Figure 1: Illustrations of (a) a prior distribution $p(\theta,\phi)$; (b) a posterior distribution $q(\theta,\phi)$ with non-negligible mutual information between $\theta$ and $\phi$ (i.e., the last term of the decomposed KL divergence); and (c) a posterior distribution with small mutual information. In reality, $\theta$ and $\phi$ are much higher dimensional.
  • Figure 2: Update step of Algorithm 1 at the $k$-th iteration.
  • Figure 3: Data and predictions in the light tunnel task: (top row) ground-truth images; (middle two rows) predictions from the ERM model; and (bottom two rows) predictions from the FSAM model. For each model, the upper row shows the predictions $\tilde{z}$ from the scientific model part (model F2), and the lower row shows the full predictions $\tilde{y}$.
  • Figure 4: (b) Effects of the hyperparameter configurations. The dashed lines represent the average performance of ERM.
  • Figure 5: Sensitivity of performance with regard to the hyperparameter configurations.
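The mutual-information term in Figure 1 can be checked numerically on small discrete distributions. This sketch assumes a factorized prior $p(\theta,\phi) = p(\theta)p(\phi)$ (an assumption, not confirmed from the paper) and verifies the standard identity $\mathrm{KL}(q(\theta,\phi)\,\|\,p(\theta)p(\phi)) = \mathrm{KL}(q(\theta)\,\|\,p(\theta)) + \mathrm{KL}(q(\phi)\,\|\,p(\phi)) + I_q(\theta;\phi)$, where $I_q(\theta;\phi) = \mathrm{KL}(q(\theta,\phi)\,\|\,q(\theta)q(\phi))$. The grid sizes and probabilities are made up for illustration.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions with full support."""
    p = np.asarray(p, dtype=float).ravel()
    q = np.asarray(q, dtype=float).ravel()
    return float(np.sum(p * np.log(p / q)))

rng = np.random.default_rng(0)

# Hypothetical joint posterior q(theta, phi) on a 3 x 4 grid (strictly positive).
q_joint = rng.random((3, 4)) + 0.1
q_joint /= q_joint.sum()

# Hypothetical factorized prior p(theta) p(phi).
p_theta = np.array([0.2, 0.5, 0.3])
p_phi = np.array([0.1, 0.2, 0.3, 0.4])
p_joint = np.outer(p_theta, p_phi)

# Marginals of the posterior.
q_theta = q_joint.sum(axis=1)
q_phi = q_joint.sum(axis=0)

lhs = kl(q_joint, p_joint)
mutual_info = kl(q_joint, np.outer(q_theta, q_phi))
rhs = kl(q_theta, p_theta) + kl(q_phi, p_phi) + mutual_info
print(f"KL = {lhs:.6f}, sum of terms = {rhs:.6f}, I(theta; phi) = {mutual_info:.6f}")
```

A posterior of the form $q(\theta)q(\phi)$, as in panel (c) in the limit, makes the last term vanish, leaving only the two marginal KL terms.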