
Proactive Gradient Conflict Mitigation in Multi-Task Learning: A Sparse Training Perspective

Zhi Zhang, Jiayi Shen, Congfeng Cao, Gaole Dai, Shiji Zhou, Qizhe Zhang, Shanghang Zhang, Ekaterina Shutova

TL;DR

This paper systematically investigates the occurrence of gradient conflict across different methods and proposes a strategy to reduce such conflicts through sparse training (ST), wherein only a portion of the model's parameters are updated during training while keeping the rest unchanged.

Abstract

Advancing towards generalist agents necessitates the concurrent processing of multiple tasks using a unified model, thereby underscoring the growing significance of simultaneous model training on multiple downstream tasks. A common issue in multi-task learning is the occurrence of gradient conflict, which leads to potential competition among different tasks during joint training. This competition often results in improvements in one task at the expense of deterioration in another. Although several optimization methods have been developed to address this issue by manipulating task gradients for better task balancing, they cannot decrease the incidence of gradient conflict. In this paper, we systematically investigate the occurrence of gradient conflict across different methods and propose a strategy to reduce such conflicts through sparse training (ST), wherein only a portion of the model's parameters are updated during training while keeping the rest unchanged. Our extensive experiments demonstrate that ST effectively mitigates conflicting gradients and leads to superior performance. Furthermore, ST can be easily integrated with gradient manipulation techniques, thus enhancing their effectiveness.
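
A single sparse-training step can be sketched as follows. This is a minimal illustration, not the authors' implementation. It assumes plain SGD, a fixed binary mask chosen beforehand, and joint training that sums the per-task gradients. The helper names (`sparse_sgd_step`, `joint_gradient`) and all numbers are hypothetical toy values. Entries with mask 1 are updated, and entries with mask 0 stay unchanged.

```python
# Minimal sketch of one sparse-training (ST) update.
# Assumptions (not from the paper): plain SGD, a fixed binary mask,
# and joint training that sums per-task gradients. Values are toy data.

def joint_gradient(task_grads):
    """Joint-training baseline: sum the per-task gradients coordinate-wise."""
    return [sum(coords) for coords in zip(*task_grads)]


def sparse_sgd_step(params, grads, mask, lr=0.1):
    """Update only parameters whose mask entry is 1; keep the rest frozen."""
    return [p - lr * g if m else p for p, g, m in zip(params, grads, mask)]


params = [1.0, 2.0, 3.0, 4.0]
g_task1 = [0.5, -1.0, 0.2, 0.0]
g_task2 = [0.5, 1.0, -0.4, 0.3]
mask = [1, 0, 1, 0]  # hypothetical: only parameters 0 and 2 are trainable

g = joint_gradient([g_task1, g_task2])
new_params = sparse_sgd_step(params, g, mask)
print(new_params)  # parameters 1 and 3 are left unchanged
```

Because the mask is applied only at the final update, the summed gradient in `joint_gradient` could be replaced by the output of a gradient-manipulation method (e.g. PCGrad) without changing `sparse_sgd_step`. This matches the abstract's statement that ST can be combined with such techniques.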

Paper Structure

This paper contains 55 sections, 13 equations, 13 figures, 17 tables.

Figures (13)

  • Figure 1: Average occurrence percentage of gradient conflict over epochs (all epochs / last 50% of epochs) when training the SAM model on the NYUv2 dataset, evaluated for various methods, including joint training and gradient manipulation techniques.
  • Figure 2: Visualization of how gradients change under different methods. $g_i$ and $g_j$ are two conflicting gradients, and the green arrow is the actual update vector. Sparse training can be interpreted as an orthogonal (coordinate) projection of the conflicting gradients onto the subspace defined by the selected parameters, which yields better alignment between the projected gradients.
  • Figure 3: PSN. For each neuron, the top-1 (highest-magnitude) parameter among all of its input connections is selected.
  • Figure 4: The incidence of GC between tasks during training of SAM on the NYUv2 dataset. The top and bottom panels show Joint Train and PCGrad, respectively. See Figure \ref{fig:number of gc} in Section \ref{sec: supp NYU-v2 on SAM} for results on other gradient manipulation methods.
  • Figure 5: Ablation study for Joint Train on the NYU-v2 dataset. (a) Average incidence of GC during joint training for Swin transformers of different sizes; numerical statistics for all epochs are given in Table \ref{tab:gc of different swin model size} in Section \ref{sec: supp NYU-v2 on Swin}. (b) Different numbers of trainable parameters for the MTAN and SAM models. (c) Different sparse training methods applied to SAM. Metrics for all tasks are min-max normalized; detailed results are given in Table \ref{tab:ablation (full): different sparse training methods} in Section \ref{sec: supp: Ablation study}.
  • ...and 8 more figures
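
The captions of Figures 2 and 3 can be illustrated together with a toy example. This is a sketch of our reading of the captions, not the paper's code. It makes three assumptions: a weight matrix is stored as rows of input connections (one row per neuron), the PSN mask keeps the single largest-magnitude weight per row, and the mask is applied to the flattened gradients. The function names and numbers are hypothetical. The example is deliberately chosen so that projection turns a negative cosine into a positive one. In general, whether alignment improves depends on which coordinates are kept.

```python
import math

# Toy illustration of Figures 2 and 3 under stated assumptions:
# weight[r][c] is the c-th input connection of neuron r, and the mask
# keeps the top-1 |weight| per neuron. All values are hypothetical.

def psn_mask(weight):
    """Per neuron (row), keep only the input connection with the largest
    absolute value; return a flattened 0/1 mask in row-major order."""
    mask = []
    for row in weight:
        best = max(range(len(row)), key=lambda c: abs(row[c]))
        mask.extend(1 if c == best else 0 for c in range(len(row)))
    return mask


def project(grad, mask):
    """Coordinate projection: zero the gradient entries of frozen parameters."""
    return [g if m else 0.0 for g, m in zip(grad, mask)]


def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


weight = [[0.1, -0.9, 0.3],
          [0.8, 0.2, -0.1]]
mask = psn_mask(weight)  # [0, 1, 0, 1, 0, 0]

# Flattened gradients of two tasks w.r.t. the six weights (toy values).
g_i = [2.0, 1.0, 0.0, 1.0, 0.0, 0.0]
g_j = [-3.0, 0.5, 0.0, 1.0, 0.0, 0.0]

before = cosine(g_i, g_j)                                # negative: conflict
after = cosine(project(g_i, mask), project(g_j, mask))   # positive after projection
print(before, after)
```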

Theorems & Definitions (2)

  • Definition 1: Gradient Conflict
  • Definition 2: Sparse Training
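
The paper's formal Definition 1 is not reproduced on this page. The sketch below therefore uses the definition common in the gradient-manipulation literature, which is an assumption: two task gradients conflict when their cosine similarity is negative. For non-zero gradients this is equivalent to a negative dot product. The sketch also computes the share of conflicting task pairs at one training step, which is the kind of incidence statistic reported in Figures 1 and 4. The function names and gradient values are hypothetical.

```python
from itertools import combinations

# Assumed definition (not quoted from the paper): gradients g_i and g_j
# conflict when cos(g_i, g_j) < 0, i.e. their dot product is negative.

def is_conflicting(g_i, g_j):
    """True when the two (non-zero) task gradients point in opposing directions."""
    return sum(a * b for a, b in zip(g_i, g_j)) < 0


def conflict_percentage(task_grads):
    """Percentage of task pairs whose gradients conflict at one step."""
    pairs = list(combinations(range(len(task_grads)), 2))
    n_conflict = sum(is_conflicting(task_grads[i], task_grads[j]) for i, j in pairs)
    return 100.0 * n_conflict / len(pairs)


# Hypothetical shared-parameter gradients for three tasks
# (e.g. the segmentation, depth and surface-normal tasks of NYUv2).
grads = {
    "segmentation": [1.0, -0.5, 0.2],
    "depth": [-0.8, 0.4, 0.1],
    "normal": [0.3, 0.1, 0.9],
}
pct = conflict_percentage(list(grads.values()))
print(pct)  # 2 of the 3 task pairs conflict here
```

Averaging such per-step percentages over epochs (or over only the last 50% of epochs) would give a summary comparable in spirit to Figure 1. The exact aggregation used in the paper is not stated on this page.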