Learning to Rebalance Multi-Modal Optimization by Adaptively Masking Subnetworks

Yang Yang, Hongpeng Pan, Qing-Yuan Jiang, Yi Xu, Jinghui Tang

TL;DR

The paper tackles modality imbalance in multi-modal learning by moving from global, modal-level gradient control to fine-grained, element-wise updates through Adaptively Mask Subnetworks Considering Modal Significance (AMSS). By quantifying modal significance via a variational mutual information rate and restricting updates to subnetworks sampled according to Fisher information, AMSS rebalances learning across modalities; AMSS+ further introduces an unbiased masking scheme to improve convergence. The authors provide convergence analyses for both the biased (AMSS) and unbiased (AMSS+) gradient settings and report substantial empirical gains across diverse datasets and backbones, including transformer-based architectures. The approach is presented as a plug-in for existing multi-modal models that improves the utilization of non-dominant modalities, with robust performance gains backed by theoretical guarantees.
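The element-wise mechanism described above can be sketched in a few lines of Python. This is a minimal, non-authoritative sketch under stated assumptions: Fisher information is approximated by the diagonal empirical Fisher (mean squared gradient per parameter), the mask is drawn by Fisher-weighted sampling without replacement, and `mask_ratio`, which maps two modal significance scores to the fraction of parameters a modality updates, is a hypothetical stand-in for the paper's mutual-information-rate criterion. All function names are illustrative.

```python
import random

def empirical_fisher_diag(grad_samples):
    """Diagonal empirical Fisher: mean squared gradient per parameter."""
    n = len(grad_samples)
    dim = len(grad_samples[0])
    return [sum(g[i] ** 2 for g in grad_samples) / n for i in range(dim)]

def mask_ratio(sig_self, sig_other, min_ratio=0.1):
    """HYPOTHETICAL rule: a more significant (dominant) modality
    updates a smaller fraction of its parameters."""
    ratio = sig_other / (sig_self + sig_other)
    return max(min_ratio, min(1.0, ratio))

def fisher_mask(fisher, ratio, rng):
    """Pick round(ratio * d) distinct indices without replacement,
    weighting each draw by its Fisher score; return a 0/1 mask."""
    d = len(fisher)
    k = max(1, round(ratio * d))
    remaining = list(range(d))
    mask = [0] * d
    for _ in range(k):
        weights = [fisher[j] + 1e-12 for j in remaining]  # keep weights positive
        pos = rng.choices(range(len(remaining)), weights=weights)[0]
        mask[remaining.pop(pos)] = 1
    return mask

def masked_sgd_step(w, grad, mask, lr):
    """Element-wise update: only parameters with mask == 1 change."""
    return [wi - lr * gi * mi for wi, gi, mi in zip(w, grad, mask)]

# Usage on a toy 4-parameter "modality encoder".
rng = random.Random(0)
w = [0.5, -0.2, 0.1, 0.8]
grads = [[0.9, 0.01, 0.3, 0.0], [1.1, 0.02, 0.2, 0.1]]
fisher = empirical_fisher_diag(grads)
ratio = mask_ratio(sig_self=3.0, sig_other=1.0)  # dominant modality: 0.25
mask = fisher_mask(fisher, ratio, rng)
w_new = masked_sgd_step(w, grads[-1], mask, lr=0.1)
```

In this sketch, parameters outside the sampled subnetwork keep their values, so the dominant modality (small mask ratio) moves fewer parameters per step while the forward pass still uses the full network.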

Abstract

Multi-modal learning aims to enhance performance by unifying models from various modalities, but it often faces the "modality imbalance" problem in real data, which biases learning towards dominant modalities while neglecting others and thereby limits its overall effectiveness. To address this challenge, the core idea is to balance the optimization of each modality to achieve a joint optimum. Existing approaches often employ a modal-level control mechanism to adjust the update of each modality's parameters. However, such a global-wise updating mechanism ignores the differing importance of individual parameters. Inspired by subnetwork optimization, we explore a uniform sampling-based optimization strategy and find it more effective than global-wise updating. Based on these findings, we further propose a novel importance sampling-based, element-wise joint optimization method, called Adaptively Mask Subnetworks Considering Modal Significance (AMSS). Specifically, we incorporate mutual information rates to determine modal significance and employ non-uniform adaptive sampling to select foreground subnetworks from each modality for parameter updates, thereby rebalancing multi-modal learning. Additionally, we demonstrate the reliability of the AMSS strategy through convergence analysis. Building upon these theoretical insights, we further enhance the multi-modal mask subnetwork strategy using unbiased estimation, referred to as AMSS+. Extensive experiments demonstrate the superiority of our approach over competing methods.
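To make the contrast between global-wise updating and the uniform sampling-based strategy concrete, here is a minimal Python sketch. The function names and the single-scalar form of global modulation are assumptions for illustration, not the authors' code.

```python
import random

def global_update(w, grad, coeff, lr):
    """Global-wise modulation: one scalar coefficient rescales the whole
    gradient of a modality, so every parameter moves."""
    return [wi - lr * coeff * gi for wi, gi in zip(w, grad)]

def uniform_mask_update(w, grad, keep_prob, lr, rng):
    """Element-wise modulation with a uniformly sampled subnetwork: each
    parameter is kept independently with probability keep_prob."""
    mask = [1 if rng.random() < keep_prob else 0 for _ in w]
    new_w = [wi - lr * gi * mi for wi, gi, mi in zip(w, grad, mask)]
    return new_w, mask

# Usage: the same gradient under both mechanisms.
rng = random.Random(42)
w = [1.0] * 6
g = [0.5, -1.0, 2.0, 0.3, 1.5, -0.5]
w_global = global_update(w, g, coeff=0.5, lr=0.1)
w_masked, mask = uniform_mask_update(w, g, keep_prob=0.5, lr=0.1, rng=rng)
```

Since each mask entry equals 1 with probability `keep_prob`, the uniform mask reproduces a global coefficient of `keep_prob` in expectation, yet each individual step touches only a random subset of parameters.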

Paper Structure

This paper contains 26 sections, 4 theorems, 45 equations, 6 figures, and 8 tables.

Key Result

Theorem 1

Under certain assumptions on the stochastic gradient $\nabla \ell({\bf w}(t)) \odot {\bf m}(t)$, the AMSS iterates satisfy a convergence bound governed by two constants, $\delta \in (0,1)$ and $\nu \ge 0$.
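The stochastic gradient in Theorem 1 is the element-wise product of the full gradient and a mask. Assuming, for illustration only, that conditioned on ${\bf w}(t)$ each mask entry $m_i(t)$ is an independent Bernoulli variable with keep probability $p_i$ (the paper's sampling scheme may differ), a one-line expectation shows why such a masked gradient is biased and how rescaling removes the bias:

```latex
\mathbb{E}\big[\nabla \ell({\bf w}(t)) \odot {\bf m}(t)\big]_i
  = p_i \, \nabla_i \ell({\bf w}(t)),
\qquad
\mathbb{E}\Big[\frac{m_i(t)}{p_i}\, \nabla_i \ell({\bf w}(t))\Big]
  = \nabla_i \ell({\bf w}(t)).
```

Whenever some $p_i < 1$, the first expectation differs from the true gradient, which is the biased setting; dividing each kept entry by its sampling probability (inverse-probability weighting) is one standard way to obtain an unbiased estimator, in the spirit of the unbiased estimation used by AMSS+.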

Figures (6)

  • Figure 1: Illustration of different gradient modulation schemes. Global-wise: during backward propagation, the same modulation is applied to the gradients of all parameters. Element-wise: forward propagation runs through the entire network, while during backward propagation parameter gradients are modulated differentially through a mask subnetwork.
  • Figure 2: Overall framework of our proposed AMSS strategy, using the Transformer-CNN structure as an example.
  • Figure 3: Experiments with SGD, AdaGrad, and Adam optimizers on Kinetics-Sound and Sarcasm-Detection. Baseline denotes training without extra modulation.
  • Figure 4: On the Kinetics-Sound dataset, we jointly train multi-modal models with concatenation fusion under Baseline, OGM-GE, AMSS, and AMSS+, and track how training loss and test performance change for each. Baseline denotes training without any gradient modulation strategy.
  • Figure 5: Analysis of the modality imbalance problem. For each dataset, the panels from left to right show the variation in the imbalance degree, a comparison between our method and the Baseline model, the performance of single-modal branches in multi-modal trained models, and the performance of single-modal branches with AMSS+, covering the Audio/Text and Video/Img modalities. The multi-modal model uses Concat fusion.
  • ...and 1 more figure

Theorems & Definitions (6)

  • Theorem 1: Informal, AMSS
  • Theorem 2: Informal, AMSS+
  • Theorem 3: Formal, AMSS
  • Proof 1
  • Theorem 4: Formal, AMSS+
  • Proof 2