Multimodal Continual Instruction Tuning with Dynamic Gradient Guidance

Songze Li, Mingyu Gao, Tonghua Su, Xu-Yao Zhang, Zhongjie Wang

TL;DR

This paper reframes catastrophic forgetting in multimodal continual instruction tuning (MCIT) as a missing-gradient problem, proposing a dynamic gradient guidance strategy that uses the directional vector from current parameters to previously optimal parameters to approximate old-task gradients. The method combines this gradient guidance, limited replay data, and a Bernoulli-based dynamic update rule to balance stability and plasticity, achieving state-of-the-art performance without model expansion on two MCIT benchmarks. Key contributions include a formal gradient-approximation framework, a scalable update mechanism, and comprehensive ablations across datasets with varying distribution shifts. The results suggest a practical path for robust MCIT with compact architectures, though limitations remain when distribution shifts are large and replay data storage becomes a factor.

Abstract

Multimodal continual instruction tuning enables multimodal large language models to sequentially adapt to new tasks while building upon previously acquired knowledge. However, this continual learning paradigm faces the significant challenge of catastrophic forgetting, where learning new tasks leads to performance degradation on previous ones. In this paper, we introduce a novel insight into catastrophic forgetting by conceptualizing it as a problem of missing gradients from old tasks during new task learning. Our approach approximates these missing gradients by leveraging the geometric properties of the parameter space, specifically using the directional vector between current parameters and previously optimal parameters as gradient guidance. This approximated gradient can be further integrated with real gradients from a limited replay buffer and regulated by a Bernoulli sampling strategy that dynamically balances model stability and plasticity. Extensive experiments on multimodal continual instruction tuning datasets demonstrate that our method achieves state-of-the-art performance without model expansion, effectively mitigating catastrophic forgetting while maintaining a compact architecture.
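The update described in the abstract can be sketched on a toy problem. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names `guided_step` and `train_toy`, the weight `alpha`, the fixed Bernoulli probability `p`, and the two quadratic "tasks" are all hypothetical. The paper's rule is described as dynamic, so a constant `p` is a simplification.

```python
import numpy as np


def guided_step(theta, theta_old, grad_new, grad_replay=None,
                lr=0.1, alpha=1.0, p=0.5, rng=None):
    """One gradient-descent step with Bernoulli-gated gradient guidance.

    theta       : current parameters
    theta_old   : parameters that were optimal after the previous tasks
    grad_new    : gradient of the new-task loss at theta
    grad_replay : optional real gradient from a small replay buffer
    alpha, p    : hypothetical guidance weight and gate probability
    """
    rng = np.random.default_rng() if rng is None else rng
    update = grad_new.copy()
    if grad_replay is not None:
        update = update + grad_replay
    # Bernoulli(p) gate: apply the guidance term only on sampled steps.
    if rng.random() < p:
        # Assumed surrogate for the old-task gradient: descending along
        # (theta - theta_old) moves theta back toward theta_old.
        g_hat = theta - theta_old
        update = update + alpha * g_hat
    return theta - lr * update


def train_toy(p, steps=200, seed=0):
    """Toy setup: old loss 0.5*||theta - a||^2, new loss 0.5*||theta - b||^2."""
    rng = np.random.default_rng(seed)
    a = np.array([0.0, 0.0])  # old-task optimum (theta_old)
    b = np.array([2.0, 2.0])  # new-task optimum
    theta = a.copy()
    for _ in range(steps):
        grad_new = theta - b  # exact gradient of the toy new-task loss
        theta = guided_step(theta, a, grad_new, p=p, rng=rng)
    return theta
```

In this toy setting, `train_toy(0.0)` never applies the guidance and converges to the new-task optimum `b`, which corresponds to complete forgetting. `train_toy(1.0)` always applies it and converges to the midpoint of `a` and `b`, which is the joint optimum of the two quadratic losses. Intermediate values of `p` trade stability against plasticity.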

Paper Structure

This paper contains 18 sections, 18 equations, 8 figures, 4 tables, 1 algorithm.

Figures (8)

  • Figure 1: Illustration of our novel insight into catastrophic forgetting. We attribute catastrophic forgetting to the absence of old tasks' gradients during new task learning, which prevents gradient descent from converging to the optimal parameters achievable through joint training of all tasks. To address this problem, we approximate the missing gradients of old tasks by utilizing the optimal parameters from previous tasks (red star) as directional guides. The vector connecting current model parameters to these previously optimal parameters provides geometric guidance for approximating old task gradient directions. By integrating this approximated gradient with the new task gradient, we effectively simulate the joint training gradient, thereby alleviating catastrophic forgetting.
  • Figure 1: Illustration of VQAv2 dataset.
  • Figure 2: Optimization process under different memory retention strategies. (a) Learning a new task without any memory retention mechanism. Because only new task gradients (yellow arrow) are present and old task gradients are absent, the model converges directly to the optimal solution for the new task, resulting in complete forgetting of previous knowledge. (b) Learning a new task with replay data. A limited number of replay samples provides partial gradient information from old tasks (blue arrow), enabling the model to converge to parameters that retain some memory. However, the gradients from these samples cannot represent the expected gradient over the entire old task dataset throughout the optimization process, leading to suboptimal convergence relative to multi-task learning and residual catastrophic forgetting. (c) Learning a new task with our dynamic gradient guidance. Our method approximates old task gradients by using the optimal parameters from previous tasks as directional guides (blue arrow), fused with real gradients from cached replay samples (purple arrow, combined with the new task gradient). This approximation is dynamically regulated through Bernoulli sampling (red dotted line) to control gradient update frequency, achieving balanced convergence towards joint task optimization.
  • Figure 2: Illustration of UCIT dataset.
  • Figure 3: Results of the ablation on gradient approximation. We conduct this ablation under two configurations: using only the replay buffer without $\hat{g}$ ($\mathcal{M}$-only) and using only $\hat{g}$ without the replay buffer ($\hat{g}$-only). Full denotes the complete method, which integrates both $\hat{g}$ and the replay buffer $\mathcal{M}$.
  • ...and 3 more figures
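
The geometric argument in the Figure 1 caption can be motivated with a short worked sketch. The symbols $\theta$, $\theta^{*}_{\text{old}}$, $H$, $\eta$, and $\lambda$ are illustrative assumptions rather than the paper's notation, and the quadratic approximation below is one common way to justify such a direction, not necessarily the authors' derivation. Near a minimizer $\theta^{*}_{\text{old}}$ of the old-task loss, a second-order Taylor expansion gives

$$
\mathcal{L}_{\text{old}}(\theta) \approx \mathcal{L}_{\text{old}}(\theta^{*}_{\text{old}}) + \tfrac{1}{2}\,(\theta - \theta^{*}_{\text{old}})^{\top} H \,(\theta - \theta^{*}_{\text{old}})
\quad\Longrightarrow\quad
\nabla \mathcal{L}_{\text{old}}(\theta) \approx H\,(\theta - \theta^{*}_{\text{old}}).
$$

If $H$ is crudely replaced by a positive multiple of the identity, the missing old-task gradient is proportional to the vector between the current and previously optimal parameters, $\hat{g} \propto \theta - \theta^{*}_{\text{old}}$. A descent step that adds this term to the new-task gradient $g_{\text{new}}$,

$$
\theta \leftarrow \theta - \eta\,\bigl(g_{\text{new}} + \lambda\,\hat{g}\bigr),
$$

then moves along the new-task descent direction while being pulled back toward $\theta^{*}_{\text{old}}$, which is how the combined update imitates a joint-training gradient.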