Mitigating Modality Imbalance in Multi-modal Learning via Multi-objective Optimization

Heshan Fernando, Parikshit Ram, Yi Zhou, Soham Dan, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen

TL;DR

The paper tackles modality imbalance in multi-modal learning (MML) by reframing MML as a lexicographic multi-objective optimization (MOO) problem that prioritizes the worst-performing uni-modal objective. It introduces MIMO, a gradient-based solver that optimizes a smoothly penalized objective combining the multi-modal loss with the modality-specific losses, where a smoothing term approximates the max over modalities. The authors prove convergence guarantees for the method and report improved performance over existing balanced MML and MOO baselines on popular MML benchmarks, with up to ~20x reduction in subroutine computation time. By preventing fast-learning modalities from dominating training, the approach improves generalization, and code is available for replication. The work points to extensions to early or hybrid fusion paradigms as a way to broaden applicability.
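To make the "smoothing term that approximates the max over modalities" concrete, here is a minimal sketch that uses log-sum-exp smoothing, a common smooth surrogate for a max. The function names `smooth_max` and `penalized_objective`, and the exact way the penalty is combined with the multi-modal loss, are assumptions for illustration; the paper's objective may differ, for example by a constant offset or scaling.

```python
import math


def smooth_max(values, mu):
    """Log-sum-exp surrogate for max(values) with smoothing parameter mu > 0.

    It satisfies max(values) <= smooth_max <= max(values) + mu * log(len(values)),
    so it approaches the true max as mu shrinks.
    """
    if mu <= 0:
        raise ValueError("mu must be positive")
    m = max(values)
    # Shift by the max before exponentiating for numerical stability.
    return m + mu * math.log(sum(math.exp((v - m) / mu) for v in values))


def penalized_objective(f_mm, uni_modal_losses, lam, mu):
    """Assumed form: multi-modal loss plus lam times the smoothed worst uni-modal loss."""
    return f_mm + lam * smooth_max(uni_modal_losses, mu)


# Hypothetical loss values: the second modality is the worse-performing one.
audio_loss, visual_loss = 1.0, 3.0
value = penalized_objective(0.5, [audio_loss, visual_loss], lam=2.0, mu=0.01)
```

With a small `mu` the penalty is driven almost entirely by the larger uni-modal loss, which is what steers the optimizer toward the lagging modality. A larger `mu` spreads the weight more evenly across modalities.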

Abstract

Multi-modal learning (MML) aims to integrate information from multiple modalities, which is expected to lead to superior performance over single-modality learning. However, recent studies have shown that MML can underperform, even compared to single-modality approaches, due to imbalanced learning across modalities. Methods have been proposed to alleviate this imbalance issue using different heuristics, which often lead to computationally intensive subroutines. In this paper, we reformulate the MML problem as a multi-objective optimization (MOO) problem that overcomes the imbalanced learning issue among modalities and propose a gradient-based algorithm to solve the modified MML problem. We provide convergence guarantees for the proposed method, and empirical evaluations on popular MML benchmarks showcasing the improved performance of the proposed method over existing balanced MML and MOO baselines, with up to ~20x reduction in subroutine computation time. Our code is available at https://github.com/heshandevaka/MIMO.

Paper Structure

This paper contains 30 sections, 4 theorems, 42 equations, 3 figures, 9 tables, 1 algorithm.

Key Result

Proposition 1

Under Assumptions ass:smooth and ass:lip, there exists $\hat{L}_{mm}>0$ such that $\hat{f}_{mm}$, defined in eq:penalty-smooth-cheb-two, is $\hat{L}_{mm}$-smooth (Definition def:smooth), where $\hat{L}_{mm} := L_{mm} + \lambda \sum_{k=1}^2 \left(L_{m_k} + \mu^{-1}L^2_{m_k, 1} \right)$.
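The constant in Proposition 1 can be evaluated directly, which shows how the penalty weight $\lambda$ and the smoothing parameter $\mu$ trade off against each other. The sketch below is only an arithmetic illustration: the function name and the numeric values are hypothetical, and reading $L_{m_k,1}$ as a Lipschitz constant of the $k$-th uni-modal loss is an assumption based on the label ass:lip.

```python
def smoothness_constant(L_mm, L_m, L_m1, lam, mu):
    """Evaluate L_mm + lam * sum_k (L_m[k] + L_m1[k]**2 / mu) from Proposition 1.

    L_mm : smoothness constant of the multi-modal loss
    L_m  : smoothness constants of the uni-modal losses
    L_m1 : constants L_{m_k,1} (assumed here to be Lipschitz constants)
    lam  : penalty weight (lambda), mu : smoothing parameter
    """
    if len(L_m) != len(L_m1):
        raise ValueError("L_m and L_m1 must have one entry per modality")
    if mu <= 0:
        raise ValueError("mu must be positive")
    return L_mm + lam * sum(Lk + Lk1 ** 2 / mu for Lk, Lk1 in zip(L_m, L_m1))


# Hypothetical constants for two modalities.
L_hat = smoothness_constant(L_mm=1.0, L_m=[2.0, 3.0], L_m1=[1.0, 2.0], lam=0.5, mu=0.5)
# 1.0 + 0.5 * ((2.0 + 1.0 / 0.5) + (3.0 + 4.0 / 0.5)) = 8.5

# Halving mu doubles only the mu^{-1} terms, so the constant grows.
L_hat_sharper = smoothness_constant(1.0, [2.0, 3.0], [1.0, 2.0], lam=0.5, mu=0.25)
```

Because the bound grows like $\mu^{-1}$, a sharper approximation of the max makes the penalized objective less smooth. For gradient descent on a smooth function the step size is typically taken on the order of the inverse smoothness constant, so a smaller $\mu$ would usually call for a smaller step.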

Figures (3)

  • Figure 1: Balanced multi-modal learning via multi-objective optimization. (a): Optimizing the standard MML objective can lead to slower convergence, because fast-to-learn modalities dominate the optimization process. (b): We propose MML-via-MOO (MIMO), which optimizes a modified MML objective. This allows the multi-modal network to avoid dominance by one modality, which leads to faster convergence.
  • Figure 2: Left: Comparison of the training and testing performance of the MIMO algorithm with vanilla MML (joint training with sum fusion) on the CREMA-D dataset. Middle and Right: Comparison of the loss landscapes of vanilla MML and MIMO after 1500 iterations on the CREMA-D dataset. The black solid contours denote the multi-modal training loss, and the yellow dashed contours denote the multi-modal testing loss. The red star (★) denotes the convergent point of each method. The color of the heatmap denotes the difference between uni-modal training accuracies at the given point of the loss landscape: blue denotes that the audio modality is dominating, green denotes that the visual modality is dominating, and higher color intensity denotes a larger difference in accuracy. As illustrated by the training curves and loss landscapes, MIMO achieves lower multi-modal test loss (i.e., better generalization) by balancing the learning of each modality.
  • Figure 3: Ablation of hyperparameters.

Theorems & Definitions (9)

  • Remark 1
  • Definition 1: Smoothness and smoothing function (lin2024smooth)
  • Proposition 1: Smoothness of $\hat{f}_{mm}$
  • Theorem 1: Convergence
  • Remark 2
  • proof
  • Proposition 2
  • proof
  • Theorem 2: Proposition 2 of shen2023penalty