STG-MTL: Scalable Task Grouping for Multi-Task Learning Using Data Map

Ammar Sherif, Abubakar Abid, Mustafa Elattar, Mohamed ElHelw

TL;DR

Multi-Task Learning (MTL) suffers from an exponential number of possible task groupings and from negative transfer between tasks, which limits its scalability. The authors propose STG-MTL, a data-driven framework that uses Data Maps to capture per-task training dynamics, clusters tasks with soft memberships, and applies loss weighting to train specialized, cluster-focused models. The approach scales to large task sets (demonstrated on up to 100 tasks on CIFAR100) and yields clusters that align with the groupings defined in the dataset, while keeping a modular, model-agnostic implementation. Overall, STG-MTL offers a practical path to scalable, high-performance MTL in domains with many related tasks.
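To make "an exponential number of possible task groupings" concrete, the sketch below counts candidate groupings for n tasks under two common formulations: choosing any non-empty subset of tasks to train together (2^n - 1 options), and partitioning all n tasks into disjoint groups (the Bell number B_n, which grows even faster than any fixed exponential). Which formulation a particular grouping method searches over is an assumption here, and the helpers `bell_numbers` and `grouping_counts` are hypothetical, written only for this illustration.

```python
from math import comb


def bell_numbers(n_max):
    """Return [B_0, ..., B_n_max] using the recurrence B_{n+1} = sum_k C(n, k) * B_k."""
    bell = [1]
    for n in range(n_max):
        bell.append(sum(comb(n, k) * bell[k] for k in range(n + 1)))
    return bell


def grouping_counts(n_tasks):
    """Count non-empty task subsets and full partitions of n_tasks tasks."""
    subsets = 2 ** n_tasks - 1
    partitions = bell_numbers(n_tasks)[n_tasks]
    return subsets, partitions


for n in (15, 100):  # 15 tasks as in G2, 100 tasks as in CIFAR100
    subsets, partitions = grouping_counts(n)
    print(f"{n} tasks: {subsets} subsets, a {len(str(partitions))}-digit number of partitions")
```

For the 15 tasks of G2 this gives 32,767 candidate subsets and B_15 = 1,382,958,545 partitions, which is why exhaustively evaluating groupings does not scale to 100 tasks.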

Abstract

Multi-Task Learning (MTL) is a powerful technique that has gained popularity due to its performance improvement over traditional Single-Task Learning (STL). However, MTL is often challenging: the number of possible task groupings grows exponentially with the number of tasks, and choosing the best one is difficult because some groupings degrade performance through negative interference between tasks. As a result, existing solutions suffer from severe scalability issues that limit their practical application. In this paper, we propose a new data-driven method that addresses these challenges and provides a scalable and modular solution for classification task grouping based on re-proposed data-driven features, Data Maps, which capture the training dynamics of each classification task during MTL training. Through a theoretical comparison with other techniques, we show that our approach has superior scalability. Our experiments show better performance and verify the method's effectiveness, even on an unprecedented number of tasks (up to 100 tasks on CIFAR100). As the first work to handle this many tasks, we compare the resulting grouping with the one defined in the CIFAR100 dataset and find them similar. Finally, we provide a modular implementation for easier integration and testing, with examples from multiple datasets and tasks.
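A Data Map summarizes how a model's prediction for each training example evolves across epochs. The sketch below follows the usual Data Maps definition: confidence is the mean, and variability the population standard deviation, of the probability assigned to the gold label across epochs. It assumes the probabilities were recorded from each task's head at the end of every epoch of joint MTL training. How STG-MTL turns a task's data map into a clustering feature is not spelled out in this summary, so concatenating per-example confidence and variability in `task_feature_vector` is our assumption, and both function names are hypothetical.

```python
from statistics import mean, pstdev


def data_map(gold_probs_per_epoch):
    """Data Map coordinates for one classification task.

    gold_probs_per_epoch[e][i] is the probability the task's head assigned to
    the gold label of training example i at the end of epoch e.
    Returns (confidence, variability): the per-example mean and population
    standard deviation of that probability across epochs.
    """
    n_examples = len(gold_probs_per_epoch[0])
    confidence, variability = [], []
    for i in range(n_examples):
        trajectory = [epoch[i] for epoch in gold_probs_per_epoch]
        confidence.append(mean(trajectory))
        variability.append(pstdev(trajectory))
    return confidence, variability


def task_feature_vector(gold_probs_per_epoch):
    """ASSUMPTION: flatten a task's data map into one vector for clustering by
    concatenating the confidence and variability of every example."""
    confidence, variability = data_map(gold_probs_per_epoch)
    return confidence + variability


# Two epochs, two examples: example 0 is learned late, example 1 is easy throughout.
print(data_map([[0.2, 0.9], [0.6, 0.9]]))
```

Examples with high variability are the ones whose predictions fluctuate during training; comparing these patterns across tasks is what lets related tasks end up close to each other in feature space.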

Paper Structure (17 sections, 2 equations, 106 figures, 2 tables)

Figures (106)

  • Figure 1: Overview of our method to cluster the tasks using Data Maps. $(1)$ We use a single multi-head Multi-Task Learning architecture to jointly train all the tasks; each head consists of task-specific layers. $(2)$ We extract the data maps of all the tasks across the epochs in $E$. $(3)$ We cluster the tasks on their data maps using k-means and generate the memberships according to the fuzzification equation. $(4)$ To evaluate the clustering results, we train $m$ models, each representing one cluster and focusing on particular tasks by using the memberships as loss weights.
  • Figure 2: An example of a generated data map for the "Living being" task after 21 epochs of co-training on the 15 tasks of G2 (see the datasets and tasks section)
  • Figure 3: The procedure for using our specialized trained models to infer the results
  • Figure 4: G2: $2$ clusters
  • Figure 5: G2: $3$ clusters
  • ...and 101 more figures
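Steps (3) and (4) of Figure 1, together with inference (Figure 3), can be sketched as follows. This is a minimal pure-Python sketch under stated assumptions: the paper's fuzzification equation is not reproduced in this summary, so the standard fuzzy c-means membership rule stands in for it, with a hypothetical `fuzziness` parameter; routing each task to the model of its highest-membership cluster at inference time is likewise our assumption about Figure 3. All function names are hypothetical.

```python
import math
import random


def kmeans(points, k, iters=100, seed=0):
    """Plain Lloyd's k-means over task feature vectors (e.g. flattened data maps)."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        groups = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda c: math.dist(p, centers[c]))
            groups[nearest].append(p)
        # Move each center to the mean of its group; an empty cluster keeps its center.
        new_centers = [
            [sum(col) / len(g) for col in zip(*g)] if g else centers[j]
            for j, g in enumerate(groups)
        ]
        if new_centers == centers:
            break
        centers = new_centers
    return centers


def soft_memberships(points, centers, fuzziness=2.0):
    """ASSUMPTION: fuzzy c-means memberships stand in for the paper's
    fuzzification equation. Each row (one task) sums to 1 across clusters."""
    exponent = 2.0 / (fuzziness - 1.0)
    rows = []
    for p in points:
        dists = [math.dist(p, c) for c in centers]
        if any(d == 0.0 for d in dists):
            # A task lying exactly on a center belongs fully to that center.
            hits = [1.0 if d == 0.0 else 0.0 for d in dists]
            rows.append([h / sum(hits) for h in hits])
        else:
            rows.append([1.0 / sum((dj / dk) ** exponent for dk in dists) for dj in dists])
    return rows


def cluster_loss(task_losses, memberships, cluster):
    """Loss of the model specializing in `cluster`: task losses weighted by membership."""
    return sum(memberships[t][cluster] * loss for t, loss in enumerate(task_losses))


def route_tasks(memberships):
    """HYPOTHETICAL inference rule: serve each task with its highest-membership cluster's model."""
    return [max(range(len(row)), key=row.__getitem__) for row in memberships]


# Toy example: four tasks with 1-D features forming two obvious groups.
tasks = [[0.0], [0.1], [1.0], [1.1]]
U = soft_memberships(tasks, kmeans(tasks, k=2))
print(route_tasks(U))
```

In a real training loop, `task_losses` would hold framework tensors (one per task head) rather than floats; the weighted sum works unchanged because tensors support `*` and `+`. Soft weights let a task that sits between clusters still contribute to several specialized models instead of being forced into one.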