Scalable Multitask Learning Using Gradient-based Estimation of Task Affinity

Dongyue Li, Aneesh Sharma, Hongyang R. Zhang

TL;DR

This work presents Grad-TAG, a new algorithm that trains a "base" model on all tasks and then uses a linearization technique to estimate the loss of a model trained on any specific task combination. It also designs a semi-definite program that groups similar tasks by maximizing the average density of clusters.

Abstract

Multitask learning is a widely used paradigm for training models on diverse tasks, with applications ranging from graph neural networks to language model fine-tuning. Since tasks may interfere with each other, a key notion for modeling their relationships is task affinity. This includes pairwise task affinity, computed among pairs of tasks, and higher-order affinity, computed among subsets of tasks. Naively computing either of them requires repeatedly training on data from various task combinations, which is computationally intensive. We present a new algorithm Grad-TAG that can estimate task affinities without this repeated training. The key idea of Grad-TAG is to train a "base" model for all tasks and then use a linearization technique to estimate the loss of the model for a specific task combination. The linearization works by computing a gradient-based approximation of the loss, using low-dimensional projections of gradients as features in a logistic regression to predict labels for the task combination. We show that the linearized model can provably approximate the loss when the gradient-based approximation is accurate, and also empirically verify that on several large models. Then, given the estimated task affinity, we design a semi-definite program for clustering similar tasks by maximizing the average density of clusters. We evaluate Grad-TAG's performance across seven datasets, including multi-label classification on graphs, and instruction fine-tuning of language models. Our task affinity estimates are within 2.7% distance to the true affinities while needing only 3% of FLOPs in full training. On our largest graph with 21M edges and 500 labeling tasks, our algorithm delivers estimates within 5% distance to the true affinities, using only 112 GPU hours. Our results show that Grad-TAG achieves excellent performance and runtime tradeoffs compared to existing approaches.
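
To make the linearization step concrete, here is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the authors' implementation: the tanh "base" model, the ±1/sqrt(d) random projection, the plain gradient-descent solver, and all function names are hypothetical choices made for the example. The sketch computes per-example gradients at the base parameters, projects them to d dimensions, and fits a logistic regression on the projected gradients (with the base output as an offset) for a chosen task subset.

```python
import math
import random

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def base_output_and_grad(theta, x):
    """Toy base model f(x) = tanh(theta . x) and its gradient w.r.t. theta (assumed model)."""
    f = math.tanh(sum(t * xi for t, xi in zip(theta, x)))
    return f, [(1.0 - f * f) * xi for xi in x]

def random_projection(p, d, rng):
    """d x p matrix with +-1/sqrt(d) entries (one possible projection, assumed here)."""
    s = 1.0 / math.sqrt(d)
    return [[rng.choice((-s, s)) for _ in range(p)] for _ in range(d)]

def linearized_loss(W, feats, offsets, labels):
    """Mean logistic loss of the linearized model f0 + <g, W>, labels in {-1, +1}."""
    total = 0.0
    for g, f0, y in zip(feats, offsets, labels):
        z = f0 + sum(w * gi for w, gi in zip(W, g))
        total += softplus(-y * z)
    return total / len(labels)

def fit_linearized(feats, offsets, labels, steps=200):
    """Gradient descent on the convex linearized loss, starting from W = 0."""
    d, n = len(feats[0]), len(labels)
    max_sq = max(sum(gi * gi for gi in g) for g in feats)
    lr = 4.0 / (max_sq + 1e-12)  # 1/L, using the smoothness bound L <= max ||g||^2 / 4
    W = [0.0] * d
    for _ in range(steps):
        grad = [0.0] * d
        for g, f0, y in zip(feats, offsets, labels):
            z = f0 + sum(w * gi for w, gi in zip(W, g))
            coeff = -y * sigmoid(-y * z)  # derivative of softplus(-y * z) in z
            for j in range(d):
                grad[j] += coeff * g[j] / n
        W = [w - lr * gj for w, gj in zip(W, grad)]
    return W

def estimate_subset_loss(theta, P, tasks, subset):
    """Return (base-model loss, fitted linearized-model loss) on the pooled subset data."""
    feats, offsets, labels = [], [], []
    for t in subset:
        for x, y in tasks[t]:
            f0, grad = base_output_and_grad(theta, x)
            # Projected gradient P @ grad is the d-dimensional regression feature.
            feats.append([sum(pij * gj for pij, gj in zip(row, grad)) for row in P])
            offsets.append(f0)
            labels.append(y)
    W = fit_linearized(feats, offsets, labels)
    zero = [0.0] * len(W)
    return (linearized_loss(zero, feats, offsets, labels),
            linearized_loss(W, feats, offsets, labels))

# Synthetic demo: 3 binary tasks, p = 20 parameters, projection dimension d = 8.
rng = random.Random(0)
p, d = 20, 8
theta_star = [rng.uniform(-0.1, 0.1) for _ in range(p)]  # stand-in for a trained base model
tasks = {}
for t in range(3):
    w_true = [rng.gauss(0, 1) for _ in range(p)]
    data = []
    for _ in range(40):
        x = [rng.uniform(-1, 1) for _ in range(p)]
        y = 1 if sum(a * b for a, b in zip(w_true, x)) >= 0 else -1
        data.append((x, y))
    tasks[t] = data
P = random_projection(p, d, rng)
base_loss, fitted_loss = estimate_subset_loss(theta_star, P, tasks, subset=[0, 2])
print(base_loss, fitted_loss)
```

Calling `estimate_subset_loss` with a different `subset` reuses the same base parameters and projection, so only the small d-dimensional regression is refit per task combination. That reuse is the source of the savings the abstract describes relative to retraining a full model for every combination.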

Paper Structure

This paper contains 34 sections, 1 theorem, 15 equations, 6 figures, 4 tables, 2 algorithms.

Key Result

Proposition 3.3

Let $\mathcal{D}$ be a search space whose radius is at most $D$. Suppose the gradient of $f_{\theta^{\star}}$ at the initialization $\theta^{\star}$ in the training set is at most $G$ in Euclidean norm. For each task $i = 1, 2, \dots, n$, let $T_i$ denote the training data. Suppose that for every $i$ [...] Provided that $d = O(\frac{\log p}{\epsilon^2})$, the training loss of $\hat{W}_S$ is bounded away [...]

Figures (6)

  • Figure 1: Visualization of the gradient-based model approximation step in our Grad-TAE algorithm, where we replace multitask training with a regression-based estimation of model parameters fine-tuned on a particular subset of tasks.
  • Figure 2: SDP relaxation
  • Figure 3: Spectral/Lloyd's clustering
  • Figure 5: The number of GPU hours vs. the number of tasks to compute pairwise affinity, evaluated on the Orkut graph with up to 500 tasks. We estimate the full training cost by training on 2000 randomly sampled subsets of tasks.
  • Figure 6: This figure illustrates the tradeoff between error rate and computation cost, measured by the number of FLOPs and GPU hours. Compared to multitask learning baselines, our approach achieves a Pareto-optimal balance between error rate and computation cost. Recall that $M$ is the number of meta-initializations used in Grad-TAG. The number of FLOPs is reported in GFLOPs. For both settings, there are $n=100$ tasks. Our approach delivers test accuracy comparable to all baselines, using 32.8$\times$ fewer FLOPs and 5.2$\times$ fewer GPU hours than all baselines.
  • ...and 1 more figure

Theorems & Definitions (9)

  • Definition 2.1: Multitask learning algorithms
  • Example 2.2: Pairwise task affinity
  • Example 2.3: High-order task affinity
  • Remark 3.1: Second-order approximation
  • Remark 3.2: Extension to multiple classification or regression
  • Proposition 3.3
  • Example 4.1: Discussion about alternative clustering algorithms
  • Remark 4.2: Approximation ratio of the SDP relaxation
  • Remark 4.3: Further variants of Grad-TAG
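
The clustering objective mentioned in the abstract, maximizing the average density of clusters, can be illustrated on a tiny example. The sketch below is an assumption-laden illustration: it takes the density of a cluster to be the sum of affinities inside the cluster divided by the cluster size, and it uses exhaustive search over assignments, which is only feasible for a handful of tasks. The paper instead solves a semi-definite programming relaxation; the function names and the toy affinity matrix `T` are hypothetical.

```python
from itertools import product

def average_density(T, clusters):
    """Mean over clusters of (sum of affinities inside the cluster) / (cluster size)."""
    total = 0.0
    for c in clusters:
        total += sum(T[i][j] for i in c for j in c) / len(c)
    return total / len(clusters)

def best_clustering(T, k):
    """Brute-force search over assignments of n tasks to exactly k non-empty clusters."""
    n = len(T)
    best_value, best_clusters = float("-inf"), None
    for labels in product(range(k), repeat=n):
        clusters = [[i for i in range(n) if labels[i] == c] for c in range(k)]
        if any(len(c) == 0 for c in clusters):
            continue  # skip assignments that leave a cluster empty
        value = average_density(T, clusters)
        if value > best_value:  # strict comparison keeps the first maximizer found
            best_value, best_clusters = value, clusters
    return best_value, best_clusters

# Toy affinity matrix: tasks 0 and 1 help each other, as do tasks 2 and 3.
T = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
]
value, clusters = best_clustering(T, k=2)
print(value, clusters)  # 2.0 [[0, 1], [2, 3]]
```

The search space grows as k^n, which is why a relaxation becomes necessary at the scale of hundreds of tasks.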