Conflict-Averse Gradient Descent for Multi-task Learning

Bo Liu, Xingchao Liu, Xiaojie Jin, Peter Stone, Qiang Liu

TL;DR

This work addresses gradient conflict in multi-task learning by introducing Conflict-Averse Gradient Descent (CAGrad), which searches a neighborhood of the average gradient for the update that maximizes the worst local improvement across tasks. The update is computed through a dual optimization over task weights, and a convergence guarantee for $0 \leq c < 1$ ensures descent on the main objective $L_0$ while the task objectives are balanced. CAGrad includes both vanilla gradient descent and MGDA as special cases, and experiments in supervised, reinforcement learning, and semi-supervised settings show improved performance over prior gradient-manipulation methods. The theory guarantees convergence to a minimum of the average loss, rather than only to some Pareto-stationary point as in earlier multi-task optimization methods.
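
The optimization problem summarized above can be written out as follows. This is a sketch of the formulation as we understand it from the CAGrad paper, not a quotation. The notation is assumed here: $g_i = \nabla L_i(\theta)$ are the task gradients, $g_0 = \nabla L_0(\theta)$ is the average gradient, $\mathcal{W}$ is the probability simplex over the $K$ tasks, and $\phi$ is a shorthand introduced for the squared radius.

$$
\max_{d \in \mathbb{R}^m} \; \min_{i \in \{1,\ldots,K\}} \langle g_i, d \rangle \quad \text{s.t.} \quad \lVert d - g_0 \rVert \leq c \, \lVert g_0 \rVert
$$

$$
w^* = \arg\min_{w \in \mathcal{W}} \; g_w^\top g_0 + \sqrt{\phi} \, \lVert g_w \rVert, \qquad g_w = \sum_{i=1}^{K} w_i g_i, \qquad \phi = c^2 \lVert g_0 \rVert^2
$$

$$
d^* = g_0 + \frac{\sqrt{\phi}}{\lVert g_{w^*} \rVert} \, g_{w^*}
$$

With $c = 0$ the feasible ball collapses to the single point $g_0$, so $d^* = g_0$ and the update is plain gradient descent on the average loss.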

Abstract

The goal of multi-task learning is to enable more efficient learning than single-task learning by sharing model structures for a diverse set of tasks. A standard multi-task learning objective is to minimize the average loss across all tasks. While straightforward, using this objective often results in much worse final performance for each task than learning them independently. A major challenge in optimizing a multi-task model is conflicting gradients: the gradients of different task objectives are not well aligned, so following the average gradient direction can be detrimental to specific tasks' performance. Previous work has proposed several heuristics that manipulate the task gradients to mitigate this problem, but most of them lack convergence guarantees and/or could converge to any Pareto-stationary point. In this paper, we introduce Conflict-Averse Gradient descent (CAGrad), which minimizes the average loss function while leveraging the worst local improvement of individual tasks to regularize the algorithm trajectory. CAGrad balances the objectives automatically and still provably converges to a minimum over the average loss. It includes the regular gradient descent (GD) and the multiple gradient descent algorithm (MGDA) in the multi-objective optimization (MOO) literature as special cases. On a series of challenging multi-task supervised learning and reinforcement learning tasks, CAGrad achieves improved performance over prior state-of-the-art multi-objective gradient manipulation methods.
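
To make the update described in the abstract concrete, here is a minimal pure-Python sketch for the two-task case. It is an illustration under stated assumptions, not the authors' implementation: the function names (`dot`, `cagrad_two_tasks`), the grid search over the task weight, and the fallback when the weighted gradient vanishes are all choices made for this example. It relies on the dual form attributed to the paper: minimize $g_w^\top g_0 + c\lVert g_0\rVert \lVert g_w\rVert$ over simplex weights $w$, then step along $g_0 + c\lVert g_0\rVert \, g_w / \lVert g_w\rVert$.

```python
import math


def dot(a, b):
    """Inner product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def cagrad_two_tasks(g1, g2, c=0.5, grid=2001):
    """Sketch of a CAGrad-style update direction for two task gradients.

    Assumption: the simplex weight w in [0, 1] is found by a simple grid
    search, which is only practical for two tasks.
    """
    # Average gradient g0 and the radius c * ||g0|| of the search ball.
    g0 = [(a + b) / 2.0 for a, b in zip(g1, g2)]
    radius = c * math.sqrt(dot(g0, g0))

    # Dual problem: minimize  g_w . g0 + radius * ||g_w||  over w in [0, 1],
    # where g_w = w * g1 + (1 - w) * g2.
    best_w, best_val = 0.0, float("inf")
    for k in range(grid):
        w = k / (grid - 1)
        gw = [w * a + (1.0 - w) * b for a, b in zip(g1, g2)]
        val = dot(gw, g0) + radius * math.sqrt(dot(gw, gw))
        if val < best_val:
            best_w, best_val = w, val

    gw = [best_w * a + (1.0 - best_w) * b for a, b in zip(g1, g2)]
    gw_norm = math.sqrt(dot(gw, gw))
    if gw_norm < 1e-12:
        # Assumption for this sketch: fall back to the average gradient
        # when the weighted gradient vanishes.
        return g0

    # Update direction: d = g0 + (radius / ||g_w||) * g_w.
    scale = radius / gw_norm
    return [a + scale * b for a, b in zip(g0, gw)]


if __name__ == "__main__":
    g1, g2 = [1.0, 0.0], [-0.5, 1.0]  # conflicting: dot(g1, g2) < 0
    d = cagrad_two_tasks(g1, g2, c=0.5)
    print("update direction:", d)
    print("per-task improvement:", dot(g1, d), dot(g2, d))
```

Setting `c=0.0` makes the radius zero, so the function returns the average gradient, which matches the statement that regular gradient descent is a special case.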

Paper Structure

This paper contains 34 sections, 5 theorems, 34 equations, 6 figures, 8 tables.

Key Result

Theorem 3.2

Assume the individual loss functions $L_0, L_1,\ldots, L_K$ are differentiable on $\mathbb{R}^m$ and their gradients $\nabla L_i(\theta)$ are all $H$-Lipschitz, i.e. $\left\lVert\nabla L_i(x) - \nabla L_i(y)\right\rVert \leq H\left\lVert x-y\right\rVert$ for $i=0,1,\ldots, K$ where $0 \leq H \leq \i

Figures (6)

  • Figure 1: The optimization challenges faced by gradient descent (GD) and existing gradient manipulation methods such as the multiple gradient descent algorithm (MGDA) (Désidéri, 2012) and PCGrad (Yu et al., 2020). MGDA, PCGrad and CAGrad are applied on top of the Adam optimizer (Kingma & Ba, 2014). For each method, we repeat 3 runs of optimization from different initial points (labeled with $\bullet$). Each optimization trajectory is colored from red to yellow. GD with Adam gets stuck on two of the initial points because the gradient of one task overshadows that of the other task, causing the algorithm to jump back and forth between the walls of a steep valley without making progress along the floor of the valley. MGDA and PCGrad stop optimization as soon as they reach the Pareto set.
  • Figure 2: The combined update vector $d$ (in red) of a two-task learning problem with gradient descent (GD), multiple gradient descent algorithm (MGDA), PCGrad and Conflict-Averse Gradient descent (CAGrad). The two task-specific gradients are labeled $g_1$ and $g_2$. MGDA's objective is given in its primal form (see the appendix on MGDA). For PCGrad, each gradient is first projected onto the normal plane of the other (the dashed arrows). Then the final update vector is the average of the two projected gradients. CAGrad finds the best update vector within a ball around the average gradient that maximizes the worse local improvement between task 1 and task 2.
  • Figure 3: The left four plots show 5 runs of each algorithm from 5 different initial parameter vectors, with trajectories colored from red to yellow. The right two plots show CAGrad's results for varying $c \in \{0, 0.2, 0.5, 0.8, 10\}$.
  • Figure 4: The average and individual training losses on the Fashion-and-MNIST benchmark by running GD, MGDA, PCGrad and CAGrad with different $c$ values. GD gets stuck at the steep valley (the area with a cloud of dots), which other methods can pass. MGDA and PCGrad converge randomly on the Pareto set.
  • Figure 5: Test loss and training time comparison on NYU-v2 and Cityscapes.
  • ...and 1 more figure
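
The PCGrad construction described in the Figure 2 caption (project each gradient onto the normal plane of the other, then average) can be sketched in a few lines of pure Python. This is an illustrative sketch, not code from the paper: the helper names `dot`, `project_away`, `pcgrad_two_tasks` and `gd_two_tasks` are hypothetical, and the rule of projecting only when the two gradients conflict (negative inner product) follows our understanding of PCGrad rather than anything stated in the caption.

```python
def dot(a, b):
    """Inner product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def project_away(g, other):
    """Project g onto the normal plane of `other` if the two conflict.

    Assumption: gradients that do not conflict (non-negative inner
    product) are left unchanged.
    """
    inner = dot(g, other)
    if inner >= 0.0:
        return list(g)
    # inner < 0 implies `other` is nonzero, so the division is safe.
    coef = inner / dot(other, other)
    return [x - coef * y for x, y in zip(g, other)]


def pcgrad_two_tasks(g1, g2):
    """Average of the two projected gradients, as in the Figure 2 caption."""
    p1 = project_away(g1, g2)
    p2 = project_away(g2, g1)
    return [(a + b) / 2.0 for a, b in zip(p1, p2)]


def gd_two_tasks(g1, g2):
    """Plain gradient descent direction: the average gradient."""
    return [(a + b) / 2.0 for a, b in zip(g1, g2)]


if __name__ == "__main__":
    g1, g2 = [1.0, 0.0], [-0.5, 1.0]  # conflicting: dot(g1, g2) < 0
    print("GD update:    ", gd_two_tasks(g1, g2))
    print("PCGrad update:", pcgrad_two_tasks(g1, g2))
```

When the two gradients do not conflict, `pcgrad_two_tasks` in this sketch returns the same vector as `gd_two_tasks`.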

Theorems & Definitions (9)

  • Definition 3.1: Pareto optimal and stationary points
  • Theorem 3.2: Convergence of CAGrad
  • Theorem A.1: Convergence of PCGrad (Yu et al., 2020)
  • Lemma A.2
  • proof
  • Theorem A.4: Convergence of CAGrad
  • proof
  • Theorem A.5
  • proof