Meta Curvature-Aware Minimization for Domain Generalization

Ziyang Chen, Yiwen Ye, Feilong Tang, Yongsheng Pan, Yong Xia

TL;DR

This work tackles domain generalization by addressing limitations of sharpness-based methods like SAM, introducing a curvature-aware training paradigm. It defines a loss-agnostic curvature metric and derives Meta Curvature-Aware Minimization (MeCAM), which minimizes the vanilla loss and the surrogate gaps of SAM and meta-learning to locate flatter minima. Theoretical results include a PAC-Bayesian generalization bound and a convergence rate of $O\left(\frac{\log T}{\sqrt{T}}\right)$ for MeCAM in non-convex stochastic optimization. Empirically, MeCAM achieves superior generalization on five DG benchmarks and demonstrates good extensibility and flatter loss landscapes, with code to be released on GitHub.

Abstract

Domain generalization (DG) aims to enhance the ability of models trained on source domains to generalize effectively to unseen domains. Recently, Sharpness-Aware Minimization (SAM) has shown promise in this area by reducing the sharpness of the loss landscape to obtain more generalized models. However, SAM and its variants sometimes fail to guide the model toward a flat minimum, and their training processes exhibit limitations, hindering further improvements in model generalization. In this paper, we first propose an improved model training process aimed at encouraging the model to converge to a flat minimum. To achieve this, we design a curvature metric that has a minimal effect when the model is far from convergence but becomes increasingly influential in indicating the curvature of the minimum as the model approaches it. We then derive a novel algorithm from this metric, called Meta Curvature-Aware Minimization (MeCAM), to minimize the curvature around local minima. Specifically, the optimization objective of MeCAM simultaneously minimizes the regular training loss, the surrogate gap of SAM, and the surrogate gap of meta-learning. We provide a theoretical analysis of MeCAM's generalization error and convergence rate, and demonstrate its superiority over existing DG methods through extensive experiments on five benchmark DG datasets: PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet. Code will be available on GitHub.
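The abstract does not give the exact form of the two surrogate gaps, so the following is only a toy sketch under stated assumptions, not the authors' implementation. It uses the standard SAM perturbation $\epsilon = \rho \, \nabla L / \lVert \nabla L \rVert$ and takes the SAM surrogate gap to be $L(\theta + \epsilon) - L(\theta)$. The "descent gap" $L(\theta) - L(\theta - \epsilon)$ and the combined `curvature_proxy` are hypothetical stand-ins chosen for illustration, as are the function names and the scalar quadratic loss. The sketch shows why a difference of two such gaps can isolate curvature: on a quadratic, the first-order terms cancel and only a term proportional to the second derivative remains.

```python
import math


def loss_and_grad(theta, c):
    """Toy scalar loss 0.5 * c * theta^2 and its gradient (c is the curvature)."""
    return 0.5 * c * theta * theta, c * theta


def surrogate_gaps(theta, c, rho=0.05):
    """Return (sam_gap, descent_gap, curvature_proxy) at theta.

    sam_gap         = L(theta + eps) - L(theta), with the SAM-style step eps
    descent_gap     = L(theta) - L(theta - eps)  (hypothetical stand-in)
    curvature_proxy = sam_gap - descent_gap      (illustrative assumption)
    """
    loss, grad = loss_and_grad(theta, c)
    # Normalized ascent direction scaled to radius rho.
    eps = rho * grad / (abs(grad) + 1e-12)
    loss_up, _ = loss_and_grad(theta + eps, c)
    loss_down, _ = loss_and_grad(theta - eps, c)
    sam_gap = loss_up - loss
    descent_gap = loss - loss_down
    return sam_gap, descent_gap, sam_gap - descent_gap


if __name__ == "__main__":
    for label, theta, c in [("sharp, near", 1.0, 10.0),
                            ("flat, near", 1.0, 1.0),
                            ("flat, far", 100.0, 1.0)]:
        sam_gap, descent_gap, proxy = surrogate_gaps(theta, c)
        print(f"{label:12s} sam_gap={sam_gap:.5f} proxy={proxy:.5f}")
```

In this toy setting the SAM gap grows with the distance from the minimum, because it contains a first-order term $c \rho \lvert\theta\rvert$, whereas the proxy equals $c \rho^2$ regardless of where it is evaluated. This is only meant to illustrate the intuition of separating curvature from the loss value; the paper's actual metric and objective may differ.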

Paper Structure

This paper contains 32 sections, 47 equations, 5 figures, 16 tables, 1 algorithm.

Figures (5)

  • Figure 1: Accuracy of our MeCAM and existing DG methods on the PACS and VLCS datasets. MeCAM achieves superior generalization performance on both datasets.
  • Figure 2: Comparison between SAM and the proposed curvature metric. (a) Illustration of a sharp local minimum at $\theta_1$ (red) and a flat local minimum at $\theta_2$ (blue), where $S(\theta)$ denotes the sharpness of the model with parameters $\theta$. While SAM prefers to optimize around $\theta_2^{sam}$, $\theta_2$ is inherently flatter than $\theta_1$. This highlights SAM's limitation in accurately characterizing sharpness, as its measure is influenced by the loss value. (b) Illustration of the proposed curvature metric $\mathcal{C}$. Unlike SAM, $\mathcal{C}$ more effectively quantifies the deviation from a flat surface, providing a better measure of sharpness in the loss landscape. A smaller value of $\mathcal{C}$ indicates a flatter minimum. Best viewed in color.
  • Figure 3: Accuracy of our MeCAM and other sharpness-based methods across various numbers of training iterations, using each domain of PACS in turn as the target domain. In each panel, the X-axis denotes the number of training iterations, where $5,000$ is the default configuration, and the Y-axis indicates the accuracy.
  • Figure 4: Visualization of the loss landscapes of ERM, SAM, and our MeCAM on the PACS dataset. The numbers represent loss values.
  • Figure I: Accuracy of our MeCAM with various $\rho$, $\alpha$, and $\beta$ on the PACS dataset. (a) MeCAM evaluated with different values of $\rho$. (b) Grid search over $\alpha$ and $\beta$.