DMTG: One-Shot Differentiable Multi-Task Grouping

Yuan Gao, Shuguo Jiang, Moran Li, Jin-Gang Yu, Gui-Song Xia

TL;DR

This work tackles the scalability challenge of Multi-Task Learning (MTL) with many tasks by proposing a one-shot Differentiable Multi-Task Grouping (DMTG) framework. It formulates Multi-Task Grouping (MTG) as a differentiable pruning problem in which a Categorical distribution assigns each of $N$ tasks to one of up to $K$ groups, each represented by an encoder branch. Training starts with $K$ branches each connected to all $N$ task heads ($KN$ heads in total) and prunes them down to $N$ heads so that every task belongs to a single group, which lets group identification and grouped task learning be optimized jointly, end to end. The approach exploits high-order task affinities, achieves $O(K)$ encoder training complexity, and reports superior results on Taskonomy-5 and CelebA-9 across multiple backbones, with ablations supporting the one-shot design, compatibility with transformer backbones, and flexible sharing of encoder layers. These findings indicate that integrating architecture search (via differentiable pruning) with grouped task learning yields both efficiency and accuracy advantages for large-scale MTG, with practical impact on real-world multi-task systems.

Abstract

We aim to address Multi-Task Learning (MTL) with a large number of tasks by Multi-Task Grouping (MTG). Given N tasks, we propose to identify the best task groups from 2^N candidates and train the model weights simultaneously in one shot, with the high-order task-affinity fully exploited. This is distinct from the pioneering methods, which sequentially identify the groups and train the model weights, and where the group identification often relies on heuristics. As a result, our method not only improves the training efficiency, but also mitigates the objective bias introduced by the sequential procedures, which can lead to a suboptimal solution. Specifically, we formulate MTG as a fully differentiable pruning problem on an adaptive network architecture determined by an underlying Categorical distribution. To categorize N tasks into K groups (represented by K encoder branches), we initially set up KN task heads, where each branch connects to all N task heads to exploit the high-order task-affinity. Then, we gradually prune the KN heads down to N by learning a relaxed differentiable Categorical distribution, ensuring that each task is exclusively and uniquely categorized into only one branch. Extensive experiments on the CelebA and Taskonomy datasets with detailed ablations show the promising performance and efficiency of our method. The code is available at https://github.com/ethanygao/DMTG.
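The pruning from KN heads down to N can be illustrated with a small sketch. This is not the authors' implementation: it assumes the relaxed Categorical distribution is realized with a Gumbel-softmax relaxation (the abstract does not name the relaxation), and the names `relaxed_assignment`, `hard_assignment`, `logits`, and `tau` are hypothetical. Each task owns one column of K logits; a lower temperature pushes the sampled column toward one-hot, and an argmax readout keeps exactly one head per task.

```python
import numpy as np

def relaxed_assignment(logits, tau, rng):
    """Sample a relaxed one-hot group assignment per task (assumed Gumbel-softmax).

    logits: array of shape (K, N), one column of group logits per task.
    tau:    temperature > 0; smaller values give sharper columns.
    """
    u = rng.uniform(low=1e-12, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))                         # Gumbel(0, 1) noise
    scores = (logits + gumbel) / tau
    scores = scores - scores.max(axis=0, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=0, keepdims=True)  # softmax over the K groups

def hard_assignment(logits):
    """Keep one head per task: a (K, N) 0/1 mask with a single 1 in each column."""
    mask = np.zeros_like(logits)
    mask[np.argmax(logits, axis=0), np.arange(logits.shape[1])] = 1.0
    return mask

K, N = 3, 4                            # 4 tasks, 3 groups, as in Figure 1
rng = np.random.default_rng(0)
logits = rng.normal(size=(K, N))       # hypothetical learnable grouping logits
soft = relaxed_assignment(logits, tau=0.5, rng=rng)
mask = hard_assignment(logits)
print(soft.sum(axis=0))                # each column sums to 1 (up to rounding)
print(int(mask.sum()))                 # 4 heads remain out of K * N = 12
```

In this sketch the soft weights would scale the KN head losses during training, while the hard mask is the final grouping read out once the logits have converged.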

Paper Structure

This paper contains 26 sections, 8 equations, 4 figures, 11 tables.

Figures (4)

  • Figure 1: We formulate the Multi-Task Grouping (MTG) problem as network pruning. This figure illustrates the categorization of 4 tasks into 3 groups, where each branch represents a task group. As shown in the upper subfigure, at initialization each group connects to all the task heads, ensuring full exploration of high-order task-affinity. Throughout MTG training, we simultaneously prune the task heads and train the weights of the group-specific branches. Our training process ensures that MTG converges to a categorization where each task exclusively and uniquely belongs to only one group, as illustrated in the lower subfigure.
  • Figure 2: Overview of our method. We formulate the Multi-Task Grouping (MTG) problem as network pruning, where our method consists of a grouped task learning module and a group identification module. To categorize $N$ tasks into $K$ groups, our network is constructed with $K$ group-specific branches, optionally with shared lower layers. At initialization, we connect each branch to all the task heads (enabling each branch to predict all tasks), so that the high-order task-affinity can be exploited. We then formulate the grouped task learning as training the model weights of each group-specific branch, and the group identification as pruning the network heads. The final grouped task losses are generated by the element-wise product of the outputs of both modules, which in turn ensures that both modules are trained simultaneously in one shot with the high-order task-affinity fully exploited. This figure illustrates categorizing 3 tasks into 2 groups.
  • Figure A1: Illustration of the CelebA and Taskonomy datasets. The CelebA subfigure shows several image samples across different genders and races, while the Taskonomy subfigure shows an indoor RGB image with the 5 annotated labels used in our experiments.
  • Figure A2: Normalized gain w.r.t. classification errors, NormGain$_E$, of our method relative to Naive MTL for each task on the CelebA dataset with $N=40$ tasks and $K=5$ groups. Other parameters are identical to those in the CelebA results table.
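The element-wise product mentioned in the Figure 2 caption can be sketched on a toy problem. The sketch below makes several simplifying assumptions that are not from the paper: the $K \times N$ matrix of per-branch task losses is held fixed with made-up values (in the actual method the branches are trained at the same time), a deterministic softmax stands in for the relaxed Categorical distribution, and the function names are hypothetical. It shows that minimizing the product of assignment probabilities and losses moves each task toward the branch where its loss is lowest.

```python
import numpy as np

def softmax_columns(logits):
    """Softmax over the group axis, so each task's column sums to 1."""
    shifted = logits - logits.max(axis=0, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=0, keepdims=True)

def grouped_loss(task_losses, logits):
    """Element-wise product of assignment probabilities and per-head losses, summed."""
    return float((softmax_columns(logits) * task_losses).sum())

def logits_gradient(task_losses, logits):
    """Analytic gradient of grouped_loss with respect to the grouping logits."""
    probs = softmax_columns(logits)
    expected = (probs * task_losses).sum(axis=0, keepdims=True)
    return probs * (task_losses - expected)

# 2 groups x 3 tasks, as in Figure 2; the loss values are made up for illustration.
task_losses = np.array([[0.2, 0.9, 0.5],
                        [0.8, 0.3, 0.4]])
logits = np.zeros_like(task_losses)    # uniform assignment at initialization
for _ in range(500):                   # plain gradient descent on the logits only
    logits -= 1.0 * logits_gradient(task_losses, logits)

groups = np.argmax(logits, axis=0)     # one group index per task after "pruning"
print(groups)                          # [0 1 1]
```

Because the losses are fixed here, the result is simply the per-task minimum; the interest of the one-shot formulation is that the same gradient signal reaches the grouping logits while the branch weights are still changing.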