Multitask Learning Can Improve Worst-Group Outcomes

Atharva Kulkarni, Lucio Dery, Amrith Setlur, Aditi Raghunathan, Ameet Talwalkar, Graham Neubig

TL;DR

This paper investigates how multitask learning (MTL) influences worst-group outcomes when fine-tuning pre-trained models and proposes a regularized MTL approach that combines the end-task objective with a pre-training reconstruction objective, plus capacity control via $\ell_1$ regularization on the shared layer. The authors show that standard MTL is not consistently beneficial for worst-group accuracy, but the regularized variant reliably improves both average and worst-group performance across vision and language tasks, even without group annotations, and often rivals bespoke DRO methods. They validate the method on synthetic data and multiple natural datasets (Waterbirds, MultiNLI, CivilComments), demonstrate the necessity of both regularization and multitasking for gains, and emphasize the importance of pre-training for achieving these improvements. Overall, the work provides a simple, drop-in modification to existing MTL pipelines that can enhance fairness-related metrics without sacrificing aggregate accuracy, offering practical impact for deploying fairer ML systems.
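The combined objective described above can be sketched as a weighted sum of three terms. This is a minimal illustration, not the authors' implementation: the function names, the `aux_weight` and `l1_strength` parameters, and the choice to apply the $\ell_1$ penalty to the shared layer's outputs (rather than its weights) are all assumptions made for the example.

```python
from typing import Sequence


def l1_penalty(activations: Sequence[Sequence[float]]) -> float:
    """Mean absolute value over a batch of shared-layer outputs (hypothetical helper)."""
    total = sum(abs(value) for row in activations for value in row)
    count = sum(len(row) for row in activations)
    return total / count if count else 0.0


def regularized_mtl_loss(
    end_loss: float,
    aux_loss: float,
    shared_activations: Sequence[Sequence[float]],
    aux_weight: float = 1.0,    # assumed weight on the reconstruction objective
    l1_strength: float = 0.01,  # assumed strength of the l1 regularizer
) -> float:
    """End-task loss + weighted auxiliary (pre-training) loss + l1 penalty."""
    return end_loss + aux_weight * aux_loss + l1_strength * l1_penalty(shared_activations)


# Toy batch: two examples, two shared features each.
batch_activations = [[0.5, -1.5], [2.0, 0.0]]
loss = regularized_mtl_loss(
    end_loss=0.7,
    aux_loss=1.2,
    shared_activations=batch_activations,
    aux_weight=0.5,
    l1_strength=0.1,
)
print(f"combined loss: {loss:.3f}")
```

Setting `l1_strength=0.0` recovers plain multitasking, and additionally setting `aux_weight=0.0` recovers end-task-only fine-tuning, which makes it easy to compare the three regimes the summary contrasts.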

Abstract

In order to create machine learning systems that serve a variety of users well, it is vital to not only achieve high average performance but also ensure equitable outcomes across diverse groups. However, most machine learning methods are designed to improve a model's average performance on a chosen end task without consideration for their impact on worst-group error. Multitask learning (MTL) is one such widely used technique. In this paper, we seek not only to understand the impact of MTL on worst-group accuracy but also to explore its potential as a tool to address the challenge of group-wise fairness. We primarily consider the standard setting of fine-tuning a pre-trained model, where, following recent work (Gururangan et al., 2020; Dery et al., 2023), we multitask the end task with the pre-training objective constructed from the end task data itself. In settings with few or no group annotations, we find that multitasking often, but not consistently, achieves better worst-group accuracy than Just-Train-Twice (JTT; Liu et al., 2021), a representative distributionally robust optimization (DRO) method. Leveraging insights from synthetic data experiments, we propose to modify standard MTL by regularizing the joint multitask representation space. We run a large number of fine-tuning experiments across computer vision and natural language processing datasets and find that our regularized MTL approach consistently outperforms JTT on both average and worst-group outcomes. Our official code can be found here: https://github.com/atharvajk98/MTL-group-robustness.
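To make the distinction between average and worst-group accuracy concrete, the following sketch computes both from per-example predictions. The function names, group labels, and toy data are hypothetical and chosen only for illustration; they are not taken from the paper or its code.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Tuple


def group_accuracies(
    preds: Sequence[int], labels: Sequence[int], groups: Sequence[Hashable]
) -> Dict[Hashable, float]:
    """Accuracy computed separately within each group."""
    correct: Dict[Hashable, int] = defaultdict(int)
    total: Dict[Hashable, int] = defaultdict(int)
    for pred, label, group in zip(preds, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {group: correct[group] / total[group] for group in total}


def average_and_worst_group_accuracy(
    preds: Sequence[int], labels: Sequence[int], groups: Sequence[Hashable]
) -> Tuple[float, float]:
    """Return (average accuracy over all examples, minimum per-group accuracy)."""
    average = sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
    worst = min(group_accuracies(preds, labels, groups).values())
    return average, worst


# Toy data: a large group the model handles well and a small group it does not.
labels = [0] * 8 + [1, 1]
preds = [0] * 8 + [1, 0]
groups = ["majority"] * 8 + ["minority"] * 2

average_acc, worst_acc = average_and_worst_group_accuracy(preds, labels, groups)
print(f"average: {average_acc:.2f}, worst-group: {worst_acc:.2f}")
```

Because the minority group contributes few examples, a model can score well on average while failing half of that group, which is the gap that worst-group metrics are meant to expose.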

Paper Structure (31 sections, 15 equations, 6 figures, 8 tables)

Figures (6)

  • Figure 1: Visualization of synthetic training data ($1000$ points).
  • Figure 2: Predictors learned when we train on the end task only. Examples visualized are the balanced test samples created from Equation \ref{eqn:prim_gen_dist_o}.
  • Figure 3: The ratio $\log(\frac{\mathbf{a}_{\mathrm{spur}}}{\mathbf{a}_{\mathrm{core}}})$ for two (extreme) choices of $\|\mathbf{a}\|_{1}$ across four hyperparameter settings (learning rate $\times$ batch size).
  • Figure 4: Multitask learning architecture used in Section \ref{subsec:reg_mtl_synth}. We use a shared intermediate layer and two separate prediction heads for $\mathbf{T}_{\text{aux}}$ and $\mathbf{T}_{\text{end}}$.
  • Figure 5: Learned half-spaces for the multitask model under $\tau \in \{0.1, 10\}$ and $\alpha = 10$. Restricting the capacity of the shared feature space is critical for multitasking to be effective at improving worst-group error. Examples visualized are $1000$ balanced test examples sampled from Equation \ref{eqn:prim_gen_dist_o}.
  • ...and 1 more figure