Bayesian Uncertainty for Gradient Aggregation in Multi-Task Learning

Idan Achituve, Idit Diamant, Arnon Netzer, Gal Chechik, Ethan Fetaya

TL;DR

BayesAgg-MTL tackles the problem of gradient aggregation in multi-task learning by introducing a Bayesian treatment of the last-layer task parameters, which induces distributions over task gradients. By approximating these gradient distributions as Gaussians and performing moment matching, the method derives per-dimension, per-task weights to form an uncertainty-aware shared update. The paper provides concrete regression and classification procedures, including a diagonal-covariance approximation and Monte Carlo moments for classification, and demonstrates state-of-the-art performance on benchmarks such as QM9, CIFAR-100, ChestX-ray14, and UTKFace. This approach advances practical MTL by incorporating gradient uncertainty into optimization, offering improved updates and calibrated uncertainty estimates with publicly available code.
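The Monte Carlo moments mentioned for classification can be sketched as follows. This is a minimal sketch, not the paper's code. It assumes a linear softmax last layer whose weights `W` (shape C×D) follow a diagonal Gaussian with hypothetical parameters `W_mean` and `W_var`, trained with cross-entropy. The gradient with respect to the shared representation `h` is then `W^T (softmax(W h) - onehot(y))`. We sample weights, compute that gradient per sample, and keep only the per-dimension mean and variance, which mirrors a diagonal-covariance approximation. The names `mc_gradient_moments` and `n_samples` are illustrative.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mc_gradient_moments(h, y, W_mean, W_var, n_samples=2000, seed=None):
    """Monte Carlo mean and per-dimension variance of dCE/dh for one example.

    h: (D,) shared representation; y: integer class label;
    W_mean, W_var: (C, D) mean and variance of a diagonal Gaussian over
    the last-layer weights (hypothetical parameterization).
    """
    rng = np.random.default_rng(seed)
    # Draw S weight matrices, shape (S, C, D).
    W = W_mean + np.sqrt(W_var) * rng.standard_normal((n_samples,) + W_mean.shape)
    p = softmax(W @ h)                      # (S, C) class probabilities
    p[:, y] -= 1.0                          # p - onehot(y)
    grads = np.einsum("sc,scd->sd", p, W)   # W^T (p - onehot(y)) for each sample
    # Diagonal approximation: keep only the per-dimension variances.
    return grads.mean(axis=0), grads.var(axis=0)

# Usage with random (hypothetical) inputs: 3 classes, 4 feature dimensions.
rng = np.random.default_rng(0)
h = rng.standard_normal(4)
W_mean = rng.standard_normal((3, 4))
mean, var = mc_gradient_moments(h, 1, W_mean, np.full((3, 4), 0.05), seed=1)
```

With `W_var` set to zero every sample is identical, so the variance vanishes and the mean reduces to the ordinary cross-entropy gradient with respect to `h`.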

Abstract

As machine learning becomes more prominent, there is a growing demand to perform several inference tasks in parallel. Running a dedicated model for each task is computationally expensive, and therefore there is great interest in multi-task learning (MTL). MTL aims at learning a single model that solves several tasks efficiently. MTL models are often optimized by computing a single gradient per task and aggregating these gradients to obtain a combined update direction. However, these approaches overlook an important aspect: the sensitivity in the gradient dimensions. Here, we introduce a novel gradient aggregation approach using Bayesian inference. We place a probability distribution over the task-specific parameters, which in turn induces a distribution over the gradients of the tasks. This valuable additional information allows us to quantify the uncertainty in each of the gradient dimensions, which can then be factored in when aggregating them. We empirically demonstrate the benefits of our approach on a variety of datasets, achieving state-of-the-art performance.
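To illustrate how a distribution over task-specific parameters induces a distribution over gradients, consider a simple regression case. This is a minimal sketch under stated assumptions, not necessarily the paper's exact derivation. It assumes a scalar-output linear last layer with weights `w ~ N(mu, diag(s))` and the loss `0.5 * (w @ h - y)**2`. The gradient with respect to the shared representation `h` is `g = (w @ h - y) * w`, with mean `(diag(s) + mu mu^T) h - y mu`. The per-dimension variance follows from standard Gaussian moments, because `w_d` is independent of the remaining terms of `w @ h - y`. The function name `regression_gradient_moments` is hypothetical.

```python
import numpy as np

def regression_gradient_moments(h, y, mu, s):
    """Exact mean and per-dimension variance of g = (w @ h - y) * w,
    the gradient of 0.5 * (w @ h - y)**2 w.r.t. h, for w ~ N(mu, diag(s))."""
    h, mu, s = (np.asarray(a, dtype=float) for a in (h, mu, s))
    # Residual r_d = sum_{i != d} w_i h_i - y is independent of w_d.
    m_r = (mu @ h - y) - mu * h            # mean of r_d for each d
    v_r = (s * h**2).sum() - s * h**2      # variance of r_d for each d
    # Raw Gaussian moments of each w_d.
    Ew2 = mu**2 + s
    Ew3 = mu**3 + 3 * mu * s
    Ew4 = mu**4 + 6 * mu**2 * s + 3 * s**2
    # g_d = (w_d h_d + r_d) * w_d
    mean = h * Ew2 + m_r * mu
    second = h**2 * Ew4 + 2 * h * m_r * Ew3 + (v_r + m_r**2) * Ew2
    return mean, second - mean**2          # diagonal variances only

# Usage with hypothetical numbers.
mean, var = regression_gradient_moments(
    h=[0.5, -1.0, 2.0], y=0.4, mu=[0.3, -0.2, 0.1], s=[0.05, 0.1, 0.02])
```

With `s = 0` the weights are deterministic, so the variance is zero and the mean reduces to the ordinary gradient `mu * (mu @ h - y)`. Keeping only the diagonal variances parallels the diagonal-covariance approximation mentioned in the TL;DR.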

Paper Structure

This paper contains 22 sections, 16 equations, 4 figures, 7 tables, and 1 algorithm.

Figures (4)

  • Figure 1: BayesAgg-MTL assumes a probability distribution over the last layer parameters of each task. It first maps these distributions to the space of the last shared representation. Then an update direction is found for the shared representation based on the mean and variance of all distributions (denoted by X).
  • Figure 2: BayesAgg-MTL update for a two-dimensional feature representation. Black arrows indicate the mean update direction of each task; the red arrow is the update direction of a simple average; the blue arrow is the proposed update direction. Darker colors in the contours represent regions of higher density.
  • Figure 3: Mean weight over dimensions, per example, for 20 random examples on the QM9 dataset at different training stages.
  • Figure 4: Expected calibration error (ECE) vs. Brier score for the Gender and Ethnicity tasks from the UTKFace dataset. Baseline methods are shown in orange and our method in purple. Lower values are better. Our method and the top competitor are labeled on each plot.
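The two-dimensional setting of Figure 2 can be mimicked numerically. This is a minimal sketch that assumes the per-task gradient Gaussians are combined by taking the mean of their product. For diagonal covariances, that mean reduces to an inverse-variance weighted average computed separately in each dimension. This is our reading of the per-dimension, per-task weights; the paper's actual update may include additional scaling. The function name and all numbers are hypothetical.

```python
import numpy as np

def precision_weighted_aggregate(means, variances, eps=1e-12):
    """Combine K per-task gradient Gaussians with diagonal covariance.

    means, variances: (K, D) arrays. Returns the (D,) aggregated direction
    and the (K, D) per-task, per-dimension weights (each column sums to 1).
    """
    means = np.asarray(means, dtype=float)
    precision = 1.0 / (np.asarray(variances, dtype=float) + eps)
    weights = precision / precision.sum(axis=0, keepdims=True)
    return (weights * means).sum(axis=0), weights

# Hypothetical two-task example in a 2-D feature space.
means = np.array([[1.0, 1.0],     # task 1: confident in both dimensions
                  [1.0, -1.0]])   # task 2: uncertain in the second dimension
variances = np.array([[0.1, 0.1],
                      [0.1, 10.0]])
direction, weights = precision_weighted_aggregate(means, variances)
simple_average = means.mean(axis=0)
print(direction, simple_average)
```

Task 2 is uncertain about its negative second component, so that component is largely discounted. The aggregated direction is about (1.0, 0.98), whereas the simple average is (1.0, 0.0). This mirrors the contrast between the blue and red arrows in Figure 2.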