Inductive biases of multi-task learning and finetuning: multiple regimes of feature reuse

Samuel Lippl, Jack W. Lindsey

TL;DR

We describe novel implicit regularization penalties associated with multi-task learning (MTL) and pretraining followed by finetuning (PT+FT) in diagonal linear networks and single-hidden-layer ReLU networks, and find that weight rescaling improves performance when it causes models to display signatures of nested feature selection.

Abstract

Neural networks are often trained on multiple tasks, either simultaneously (multi-task learning, MTL) or sequentially (pretraining and subsequent finetuning, PT+FT). In particular, it is common practice to pretrain neural networks on a large auxiliary task before finetuning on a downstream task with fewer samples. Despite the prevalence of this approach, the inductive biases that arise from learning multiple tasks are poorly characterized. In this work, we address this gap. We describe novel implicit regularization penalties associated with MTL and PT+FT in diagonal linear networks and single-hidden-layer ReLU networks. These penalties indicate that MTL and PT+FT induce the network to reuse features in different ways. 1) Both MTL and PT+FT exhibit biases towards feature reuse between tasks, and towards sparsity in the set of learned features. We show a "conservation law" that implies a direct tradeoff between these two biases. 2) PT+FT exhibits a novel "nested feature selection" regime, not described by either the "lazy" or "rich" regimes identified in prior work, which biases it to rely on a sparse subset of the features learned during pretraining. This regime is much narrower for MTL. 3) PT+FT (but not MTL) in ReLU networks benefits from features that are correlated between the auxiliary and main task. We confirm these findings empirically with teacher-student models, and introduce a technique -- weight rescaling following pretraining -- that can elicit the nested feature selection regime. Finally, we validate our theory in deep neural networks trained on image classification. We find that weight rescaling improves performance when it causes models to display signatures of nested feature selection. Our results suggest that nested feature selection may be an important inductive bias for finetuning neural networks.
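
The weight-rescaling step mentioned above can be illustrated on a toy diagonal linear network. The sketch below is a minimal NumPy illustration under our own assumptions, not the authors' code: the network is parametrized as $\beta = w \odot v$, pretraining and finetuning both use full-batch gradient descent on squared error, all weights are finetuned starting from the rescaled values, and the helper `rescale`, the teacher vectors, and the factor 0.5 are hypothetical choices. Because $w$ and $v$ start equal here they stay equal, so this toy can only represent nonnegative coefficients, and the teachers are chosen accordingly. Rescaling both layers by a factor $c$ multiplies every effective coefficient by $c^2$.

```python
import numpy as np

def predict(w, v, X):
    # Assumed diagonal linear network: effective coefficients beta = w * v
    return X @ (w * v)

def mse(w, v, X, y):
    return 0.5 * np.mean((predict(w, v, X) - y) ** 2)

def train(w, v, X, y, lr=0.05, steps=2000):
    """Full-batch gradient descent on mse(); returns updated copies of w and v."""
    w, v = w.copy(), v.copy()
    n = X.shape[0]
    for _ in range(steps):
        grad_beta = X.T @ (predict(w, v, X) - y) / n      # dL/dbeta
        grad_w, grad_v = grad_beta * v, grad_beta * w     # chain rule through beta = w * v
        w -= lr * grad_w
        v -= lr * grad_v
    return w, v

def rescale(w, v, factor):
    """Hypothetical helper: multiply every pretrained weight by `factor` before finetuning."""
    return factor * w, factor * v

# Toy teacher-student setup (all sizes and values are illustrative assumptions)
rng = np.random.default_rng(0)
n, d = 50, 20
beta_aux = np.zeros(d); beta_aux[:5] = 1.0     # auxiliary task uses 5 features
beta_main = np.zeros(d); beta_main[:2] = 1.0   # main task uses a subset of them
X_aux, X_main = rng.standard_normal((n, d)), rng.standard_normal((n, d))
y_aux, y_main = X_aux @ beta_aux, X_main @ beta_main

init = 0.1 * np.ones(d)
w_pt, v_pt = train(init, init, X_aux, y_aux)        # pretraining on the auxiliary task
w_rs, v_rs = rescale(w_pt, v_pt, factor=0.5)        # rescaling after pretraining
loss_before = mse(w_rs, v_rs, X_main, y_main)
w_ft, v_ft = train(w_rs, v_rs, X_main, y_main)      # finetuning on the main task
loss_after = mse(w_ft, v_ft, X_main, y_main)

beta_pre, beta_rescaled = w_pt * v_pt, w_rs * v_rs
print(f"main-task training loss before/after finetuning: {loss_before:.4f} / {loss_after:.4f}")
```

Sweeping `factor` and comparing held-out main-task error is one way to probe the effect of rescaling in this toy; it does not by itself reproduce the paper's experiments.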

Paper Structure (35 sections, 7 theorems, 40 equations, 11 figures)

Key Result

Corollary 1

For the multi-output diagonal linear network defined in Eq. eq:diaglinmtl, a solution $\beta^*$ with minimal parameter norm $\|w\|_2^2+\|v\|_2^2$ subject to the constraint that it fits the training data ($X^{main} \vec{\beta}^{main} = \vec{y}^{main}, X^{aux} \vec{\beta}^{aux} = \vec{y}^{aux}$) also
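
The statement of Corollary 1 is cut off above, so its conclusion is not reconstructed here. The kind of calculation behind minimal-norm results of this form can still be checked numerically. The sketch below assumes (our assumption about Eq. eq:diaglinmtl, not stated above) that input dimension $d$ has one shared weight $w_d$ and one readout per task, so $\beta_d^{t} = w_d v_d^{t}$ for $t \in \{aux, main\}$, and writes $\beta_d = (\beta_d^{aux}, \beta_d^{main})$. Fitting the data fixes $v_d^t = \beta_d^t / w_d$, leaving the cost $w_d^2 + \|\beta_d\|_2^2 / w_d^2$, which by AM-GM is at least $2\|\beta_d\|_2$, with equality at $w_d^2 = \|\beta_d\|_2$. The code compares a grid search over $w_d$ with that closed form.

```python
import numpy as np

def fit_constrained_cost(w, beta_d):
    """Parameter norm w_d^2 + sum_t (v_d^t)^2 once the fit forces v_d^t = beta_d^t / w_d."""
    return w ** 2 + np.sum(beta_d ** 2) / w ** 2

def min_cost_grid(beta_d, grid=np.linspace(1e-3, 5.0, 200_001)):
    """Brute-force minimum over the shared weight w_d > 0 (the cost is even in w_d)."""
    return fit_constrained_cost(grid, beta_d).min()

def group_penalty(beta_d):
    """Closed form from AM-GM: w^2 + ||beta_d||^2 / w^2 >= 2 * ||beta_d||_2."""
    return 2.0 * np.linalg.norm(beta_d)

# Each row is (beta_d^aux, beta_d^main) for one input dimension d (toy values)
betas = np.array([[1.0, 0.0], [0.6, 0.8], [2.0, -1.5]])
numeric = np.array([min_cost_grid(b) for b in betas])
closed_form = np.array([group_penalty(b) for b in betas])
print(numeric.round(6), closed_form.round(6))

# Summed over d this is the group norm 2 * sum_d ||beta_d||_2. Under this assumed
# parametrization, one dimension shared by both tasks costs less than two
# dimensions that are each used by only one task:
shared = group_penalty(np.array([1.0, 1.0]))
separate = group_penalty(np.array([1.0, 0.0])) + group_penalty(np.array([0.0, 1.0]))
print(shared, separate)
```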

Figures (11)

  • Figure 1: Theoretically derived regularization penalties. a, Explicit regularization penalty associated with multi-task learning. b, Implicit regularization penalty associated with finetuning in diagonal linear networks. c, Explicit regularization penalty associated with finetuning in ReLU networks. This penalty also depends on the changes in feature direction over finetuning (measured by the correlation between the unit-normalized feature directions pre vs. post finetuning).
  • Figure 2: PT+FT and MTL benefit from feature sparsity and reuse. a,b, Generalization loss for a) diagonal linear networks and b) ReLU networks trained on a) a linear model with distinct active dimensions and b) a teacher network with distinct units between auxiliary and main task (STL: single-task learning). MTL and PT+FT benefit from a sparser teacher on the main task. c,d, Generalization loss for c) diagonal linear networks and d) ReLU networks trained on a teacher model sharing all features between the auxiliary and main task. PT+FT and MTL both generalize better than STL. e,f, Generalization loss for e) diagonal linear networks and f) ReLU networks trained on a teacher model with overlapping features. Networks benefit from feature sharing and can learn new features.
  • Figure 3: PT+FT (much more so than MTL) exhibits a nested feature selection regime. a-c, Diagonal linear networks. a, $\ell\text{-order}$/feature dependence plotted for $\beta_d^{main}=1$ and varying the auxiliary task feature coefficient. b, Generalization loss for models trained on a teacher with 40 active units during the auxiliary task and a subset of those units active during the main task. c, Generalization loss for PT+FT models whose weights are rescaled by the factor in the parentheses before finetuning. d-f, ReLU networks. d, $\ell\text{-order}$/feature dependence plotted for the explicit finetuning and MTL penalties, for $m=1$ and varying the auxiliary task feature coefficient. e, Generalization loss for models trained on a teacher network with six active units on the auxiliary task and a subset of those units on the main task. f, Generalization loss for PT+FT models whose weights are rescaled before finetuning.
  • Figure 4: PT+FT, but not MTL, in ReLU networks benefits from correlated features. a, Generalization loss for main task features that are correlated (0.9 cosine similarity) with the auxiliary task features. PT+FT outperforms both MTL and STL. b, Generalization loss for main task features with varying correlation and magnitude (mag.). PT+FT only outperforms STL if the features are either identical in direction or identical in magnitude.
  • Figure 5: Experiments in deep neural networks trained on CIFAR-100: a-c, ResNet-18; d-f, ViT. a,d, Accuracy for MTL, PT+FT, and STL in a) ResNet-18 and d) ViT. b,e, Accuracy for PT+FT with weight rescaling in b) ResNet-18 and e) ViT. c,f, Participation ratio of the layers of c) ResNet-18 and f) ViT before and after finetuning (PR Pre and PR Post), as well as their effective number of shared dimensions (ENSD).
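
Figure 5 summarizes layer representations by their participation ratio (PR) before and after finetuning and by their ENSD. A minimal NumPy sketch of both quantities follows, assuming the common definitions: PR $= (\sum_i \lambda_i)^2 / \sum_i \lambda_i^2$ over the eigenvalues $\lambda_i$ of the activation covariance, and ENSD $= \mathrm{tr}(K_A)\,\mathrm{tr}(K_B) / \mathrm{tr}(K_A K_B)$ for centered Gram matrices $K_A, K_B$ of two sets of activations on the same inputs, which reduces to PR when $A = B$. Which layers, inputs, and preprocessing were used for Figure 5 is not stated here, and the toy activations below are random placeholders.

```python
import numpy as np

def centered_gram(acts):
    """Gram matrix of mean-centered activations; `acts` has shape (n_inputs, n_units)."""
    centered = acts - acts.mean(axis=0, keepdims=True)
    return centered @ centered.T

def participation_ratio(acts):
    """PR = (sum_i lambda_i)^2 / sum_i lambda_i^2 over covariance eigenvalues lambda_i.

    The centered Gram matrix has the same nonzero eigenvalues as the unnormalized
    covariance, and PR is scale invariant, so traces of the Gram matrix suffice.
    """
    k = centered_gram(acts)
    return np.trace(k) ** 2 / np.trace(k @ k)

def ensd(acts_a, acts_b):
    """Assumed ENSD definition: tr(K_a) tr(K_b) / tr(K_a K_b); equals PR when a == b."""
    ka, kb = centered_gram(acts_a), centered_gram(acts_b)
    return np.trace(ka) * np.trace(kb) / np.trace(ka @ kb)

# Sanity check: 6 inputs spread equally along 3 orthogonal directions, plus 5 silent units
balanced = np.hstack([np.vstack([np.eye(3), -np.eye(3)]), np.zeros((6, 5))])
pr_balanced = participation_ratio(balanced)  # three equal eigenvalues, so PR is 3

# Toy "pre" and "post" activations on the same 200 inputs (random, illustrative only)
rng = np.random.default_rng(0)
pre = rng.standard_normal((200, 10)) @ rng.standard_normal((10, 64))
post = pre[:, :8]  # hypothetical stand-in for a finetuned layer with fewer active units
print(participation_ratio(pre), participation_ratio(post), ensd(pre, post))
```

PR lies between 1 and the rank of the activations, so it can be read as an effective dimensionality; ENSD plays the analogous role for the dimensions two representations share.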

Theorems & Definitions (10)

  • Corollary 1
  • Corollary 1
  • Proposition 2
  • Proposition 3
  • Corollary 3
  • Proof
  • Proposition 3
  • Proof
  • Proposition 3
  • Proof