Modality-Aware SAM: Sharpness-Aware-Minimization Driven Gradient Modulation for Harmonized Multimodal Learning
Hossein R. Nowdeh, Jie Ji, Xiaolong Ma, Fatemeh Afghah
TL;DR
This work tackles the imbalance problem in multimodal learning by introducing Modality-Aware SAM (M-SAM), an optimizer-level framework that aligns sharpness-aware updates with the dominant modality in each batch. By decomposing the per-batch loss via Shapley values and applying a targeted SAM perturbation only to the dominant modality while allowing the others to explore, M-SAM promotes flatter minima and robust convergence without altering the model architecture or requiring fixed loss weights. The approach is validated on four multimodal datasets in both early- and late-fusion setups, where M-SAM consistently surpasses state-of-the-art gradient-modulation baselines and demonstrates improved generalization and stability. The results suggest that optimization-centered strategies can effectively harmonize modality contributions and enhance multimodal learning in practice, with potentially broad impact on multimodal recognition, fusion strategies, and robust training.
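The targeted perturbation described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: weights are grouped in a hypothetical dict keyed by modality, `grad_fn` is a placeholder for whatever returns per-modality gradients, and the perturbation reuses standard SAM's normalized ascent step (rho * g / ||g||) applied only to the dominant modality's weights, after which all weights descend using the gradient evaluated at that perturbed point. How M-SAM actually modulates the non-dominant modalities may differ.

```python
import numpy as np


def targeted_sam_step(params, grad_fn, dominant, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM-style update in which only the dominant modality is perturbed.

    params: dict mapping a modality name to that modality's weights (hypothetical layout).
    grad_fn: callable returning a dict of gradients with the same keys as params.
    dominant: key of the modality whose weights receive the ascent perturbation.
    """
    # 1) Gradient at the current weights.
    grads = grad_fn(params)
    g_dom = grads[dominant]
    # 2) SAM's normalized ascent step, restricted to the dominant modality.
    e_dom = rho * g_dom / (np.linalg.norm(g_dom) + eps)
    perturbed = {k: v.copy() for k, v in params.items()}
    perturbed[dominant] = perturbed[dominant] + e_dom
    # 3) Gradient at the perturbed point; other modalities were not moved.
    adv_grads = grad_fn(perturbed)
    # 4) Descend from the original weights using the perturbed-point gradient.
    return {k: params[k] - lr * adv_grads[k] for k in params}


# Toy decoupled loss: L = sum_m 0.5 * c_m * ||w_m||^2, so dL/dw_m = c_m * w_m.
curvature = {"audio": 2.0, "video": 1.0}


def toy_grad(p):
    return {k: curvature[k] * v for k, v in p.items()}


params = {"audio": np.array([3.0, 4.0]), "video": np.array([1.0, -1.0])}
new_params = targeted_sam_step(params, toy_grad, dominant="audio", lr=0.1, rho=0.5)
# new_params["audio"] is approximately [2.34, 3.12]; new_params["video"] is [0.9, -0.9]
```

Because the toy loss is decoupled, the video weights receive an ordinary gradient step, while the audio step uses the gradient at a point moved uphill by rho; in standard SAM, this is the mechanism that biases training toward flatter minima.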
Abstract
In multimodal learning, dominant modalities often overshadow others, limiting generalization. We propose Modality-Aware Sharpness-Aware Minimization (M-SAM), a model-agnostic framework that applies to many modalities and supports both early- and late-fusion scenarios. In every iteration, M-SAM optimizes learning in three steps. \textbf{First, it identifies the dominant modality} based on each modality's contribution to accuracy, estimated with Shapley values. \textbf{Second, it decomposes the loss landscape}; in other words, it modulates the loss to prioritize the model's robustness with respect to the dominant modality. \textbf{Third, M-SAM updates the weights} by backpropagating the modulated gradients. This ensures robust learning for the dominant modality while enhancing contributions from the others, allowing the model to explore and exploit complementary features that strengthen overall performance. Extensive experiments on four diverse datasets show that M-SAM outperforms the latest state-of-the-art optimization and gradient-manipulation methods and significantly balances and improves multimodal learning.
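The first step, picking the dominant modality by its Shapley contribution, can be illustrated with an exact Shapley computation over modality subsets. This is a minimal sketch, not the authors' code: the set function `value` (here a hypothetical table of batch accuracies for each subset of active modalities) and the way absent modalities would be masked in a real model are assumptions. Exact enumeration visits every subset of the other modalities, so its cost grows exponentially with the number of modalities; that stays cheap for the handful of modalities typical in multimodal learning.

```python
from itertools import combinations
from math import factorial


def shapley_values(modalities, value):
    """Exact Shapley value of each modality under the set function `value`.

    value: callable mapping a frozenset of modality names to a score, e.g. the
    batch accuracy obtained when only those modalities are fed to the model
    (hypothetical; how absent modalities are masked is an assumption).
    """
    n = len(modalities)
    phi = {}
    for m in modalities:
        others = [o for o in modalities if o != m]
        total = 0.0
        for k in range(n):  # k = size of a coalition S that excludes m
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for coalition in combinations(others, k):
                s = frozenset(coalition)
                # Marginal contribution of m when added to coalition s.
                total += weight * (value(s | {m}) - value(s))
        phi[m] = total
    return phi


def dominant_modality(modalities, value):
    """Return the modality with the largest Shapley value, plus all values."""
    phi = shapley_values(modalities, value)
    return max(phi, key=phi.get), phi


# Hypothetical per-subset accuracies for a two-modality batch.
accuracy = {
    frozenset(): 0.1,
    frozenset({"audio"}): 0.6,
    frozenset({"video"}): 0.3,
    frozenset({"audio", "video"}): 0.7,
}
dom, phi = dominant_modality(["audio", "video"], accuracy.__getitem__)
# dom == "audio"; phi is approximately {"audio": 0.45, "video": 0.15}
```

The two values sum to 0.7 - 0.1 = 0.6, the accuracy gain of the full modality set over the empty one (the Shapley efficiency property), so each modality's share of that gain is explicit.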
