On the Emergence of Cross-Task Linearity in the Pretraining-Finetuning Paradigm

Zhanpeng Zhou, Zijun Chen, Yilan Chen, Bo Zhang, Junchi Yan

TL;DR

The paper identifies Cross-Task Linearity (CTL), a cross-task extension of Layerwise Linear Feature Connectivity (LLFC), as a prevalent phenomenon when models are finetuned on different tasks from a shared pretrained checkpoint. It shows that when the weights of two finetuned models are linearly interpolated, the features of the interpolated model at each layer closely follow the same linear interpolation of the two models' features. This suggests that the network acts approximately as a linear map from parameter space to feature space. CTL is then used to explain model merging/editing techniques such as model averaging and task arithmetic by translating parameter-space operations into the feature space. The authors also investigate the root causes, linking CTL to pretraining depth and task similarity, and provide a preliminary theoretical bound in terms of loss flatness and weight distance, leaving a fuller theory and large-language-model settings to future work.
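The TL;DR ties CTL to weight distance and flatness. A generic second-order Taylor expansion shows informally why such quantities matter; this is our own heuristic, not the paper's Theorem 5.1, whose exact bound is not given on this page. Assume each feature component $f^{(\ell)}_k(\cdot;\boldsymbol{x})$ is three times continuously differentiable in the weights, and write $\boldsymbol{\Delta} = \boldsymbol{\theta}_i - \boldsymbol{\theta}_j$ and $\boldsymbol{\theta}_\alpha = \alpha\boldsymbol{\theta}_i + (1-\alpha)\boldsymbol{\theta}_j$, so that $\boldsymbol{\theta}_i = \boldsymbol{\theta}_\alpha + (1-\alpha)\boldsymbol{\Delta}$ and $\boldsymbol{\theta}_j = \boldsymbol{\theta}_\alpha - \alpha\boldsymbol{\Delta}$. Expanding both endpoints around $\boldsymbol{\theta}_\alpha$, the first-order terms cancel in the weighted sum, leaving

$$
\alpha f^{(\ell)}_k(\boldsymbol{\theta}_i;\boldsymbol{x}) + (1-\alpha) f^{(\ell)}_k(\boldsymbol{\theta}_j;\boldsymbol{x}) - f^{(\ell)}_k(\boldsymbol{\theta}_\alpha;\boldsymbol{x})
= \frac{\alpha(1-\alpha)}{2}\,\boldsymbol{\Delta}^{\top} \nabla^2_{\boldsymbol{\theta}} f^{(\ell)}_k(\boldsymbol{\theta}_\alpha;\boldsymbol{x})\,\boldsymbol{\Delta} + O(\|\boldsymbol{\Delta}\|^3).
$$

The coefficient follows from $\alpha(1-\alpha)^2 + (1-\alpha)\alpha^2 = \alpha(1-\alpha)$. Under these assumptions the deviation from exact CTL shrinks quadratically with the distance between the two finetuned models and linearly with the curvature of the feature map along $\boldsymbol{\Delta}$. Note that this is the curvature of the feature map, not of the loss, so it only loosely parallels the flatness condition mentioned above.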

Abstract

The pretraining-finetuning paradigm has become the prevailing trend in modern deep learning. In this work, we discover an intriguing linear phenomenon in models that are initialized from a common pretrained checkpoint and finetuned on different tasks, termed Cross-Task Linearity (CTL). Specifically, we show that if we linearly interpolate the weights of two finetuned models, the features in the weight-interpolated model are often approximately equal to the linear interpolation of the features of the two finetuned models at each layer. We provide comprehensive empirical evidence that CTL consistently occurs for finetuned models that start from the same pretrained checkpoint. We conjecture that in the pretraining-finetuning paradigm, neural networks approximately function as linear maps from the parameter space to the feature space. Based on this viewpoint, our study offers new insights into model merging/editing, particularly by translating operations from the parameter space to the feature space. Furthermore, we investigate the root cause of the emergence of CTL, highlighting the role of pretraining.
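The CTL relation described in the abstract can be checked numerically. The sketch below is a toy illustration under stated assumptions, not the authors' experimental code: it uses a hypothetical three-layer ReLU MLP with random weights as the "pretrained" checkpoint and mimics finetuning with small random perturbations (an assumption; real finetuning follows task gradients). For each layer it reports $\mathbb{E}_{\mathcal{D}}[1-\cos]$ between the features of the weight-interpolated model and the interpolated features, next to the same quantity between the two finetuned models' features, mirroring the comparison in Figure 2. All function names are hypothetical.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def layer_features(params, x):
    """Per-layer features f^(l)(theta; x) of a toy 3-layer ReLU MLP (hypothetical model)."""
    w1, w2, w3 = params
    h1 = relu(x @ w1)
    h2 = relu(h1 @ w2)
    logits = h2 @ w3
    return [h1, h2, logits]


def interpolate(params_i, params_j, alpha):
    """Weight interpolation alpha * theta_i + (1 - alpha) * theta_j, tensor by tensor."""
    return [alpha * a + (1.0 - alpha) * b for a, b in zip(params_i, params_j)]


def mean_one_minus_cosine(a, b, eps=1e-12):
    """Average over samples (rows) of 1 - cosine similarity between feature vectors."""
    dots = np.sum(a * b, axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps
    return float(np.mean(1.0 - dots / norms))


def ctl_report(params_i, params_j, x, alpha=0.5):
    """Return (layer, CTL gap, baseline) for every layer."""
    feats_i = layer_features(params_i, x)
    feats_j = layer_features(params_j, x)
    feats_mix = layer_features(interpolate(params_i, params_j, alpha), x)
    report = []
    for layer, (fi, fj, fm) in enumerate(zip(feats_i, feats_j, feats_mix), start=1):
        target = alpha * fi + (1.0 - alpha) * fj  # linear interpolation of the features
        report.append((layer, mean_one_minus_cosine(fm, target), mean_one_minus_cosine(fi, fj)))
    return report


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 32))  # synthetic inputs standing in for a dataset D
shapes = [(32, 64), (64, 64), (64, 10)]
pretrained = [rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)) for m, n in shapes]


def fake_finetune(params, scale=0.1):
    """Assumption: model finetuning as a small random move away from the shared checkpoint."""
    return [w + scale * rng.normal(scale=1.0 / np.sqrt(w.shape[0]), size=w.shape) for w in params]


theta_i = fake_finetune(pretrained)
theta_j = fake_finetune(pretrained)

for layer, ctl_gap, baseline in ctl_report(theta_i, theta_j, x):
    print(f"layer {layer}: 1-cos_alpha = {ctl_gap:.2e}, 1-cos_ij = {baseline:.2e}")
```

In this toy setting the first column should be much smaller than the second: the two vectors compared in the first column differ at second order in the weight distance, while those in the second column differ at first order.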

Paper Structure (27 sections, 10 theorems, 33 equations, 24 figures, 1 table)

Key Result

Theorem 3.3

Given dataset $\mathcal{D}$, convex loss function $L$, and two modes $\boldsymbol{\theta}_i$ and $\boldsymbol{\theta}_j$ with equal loss on $\mathcal{D}$, i.e., $\mathcal{L}(\boldsymbol{\theta}_i) = \mathcal{L}(\boldsymbol{\theta}_j)$, suppose the two modes $\boldsymbol{\theta}_i$, $\boldsymbol{\theta}_j$ …
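The statement above is cut off. For intuition, here is the standard convexity argument for why LLFC at the output layer yields LMC, written under assumptions that are ours rather than the paper's exact hypotheses. The per-sample loss $\ell(\cdot, y)$ (the statement's convex $L$, in our reading) is convex in the network output $f(\boldsymbol{\theta};\boldsymbol{x})$, $\mathcal{L}$ averages it over $\mathcal{D}$, and LLFC holds exactly at the output layer with scaling coefficient 1, i.e. $f(\alpha\boldsymbol{\theta}_i + (1-\alpha)\boldsymbol{\theta}_j;\boldsymbol{x}) = \alpha f(\boldsymbol{\theta}_i;\boldsymbol{x}) + (1-\alpha) f(\boldsymbol{\theta}_j;\boldsymbol{x})$ for all $\boldsymbol{x}$ and $\alpha \in [0,1]$. Then

$$
\begin{aligned}
\mathcal{L}(\alpha\boldsymbol{\theta}_i + (1-\alpha)\boldsymbol{\theta}_j)
&= \mathbb{E}_{(\boldsymbol{x},y)\sim\mathcal{D}}\,\ell\big(\alpha f(\boldsymbol{\theta}_i;\boldsymbol{x}) + (1-\alpha) f(\boldsymbol{\theta}_j;\boldsymbol{x}),\, y\big) \\
&\le \alpha\,\mathbb{E}_{\mathcal{D}}\,\ell\big(f(\boldsymbol{\theta}_i;\boldsymbol{x}), y\big) + (1-\alpha)\,\mathbb{E}_{\mathcal{D}}\,\ell\big(f(\boldsymbol{\theta}_j;\boldsymbol{x}), y\big) \\
&= \alpha\,\mathcal{L}(\boldsymbol{\theta}_i) + (1-\alpha)\,\mathcal{L}(\boldsymbol{\theta}_j) = \mathcal{L}(\boldsymbol{\theta}_i),
\end{aligned}
$$

so under these assumptions the loss along the linear path never exceeds the common endpoint loss, a strong form of LMC. If LLFC holds only up to a positive rescaling at the output layer, this step needs an additional argument.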

Figures (24)

  • Figure 1: The spawning method and the pretraining-finetuning paradigm. $\boldsymbol{\theta}^{0}$ denotes the random initialization of the network weights. For spawning, the network is first trained for $k$ epochs to obtain $\boldsymbol{\theta}^{k}$, which is then spawned into two copies that are each updated until convergence to obtain $\boldsymbol{\theta}_1^{T}$ and $\boldsymbol{\theta}_2^{T}$. Note that $\boldsymbol{\theta}_1^{T}$ and $\boldsymbol{\theta}_2^{T}$ are trained on the same task but with different SGD noise. With a properly chosen $k$, $\boldsymbol{\theta}_1^{T}$ and $\boldsymbol{\theta}_2^{T}$ can satisfy LMC and LLFC. For pretraining-finetuning, the network is first trained on the pretraining task $\mathcal{D}_{\rm PT}$ to obtain $\boldsymbol{\theta}_{\rm PT}$, which is then finetuned on $\mathcal{D}_i$ and $\mathcal{D}_j$ to obtain $\boldsymbol{\theta}_i$ and $\boldsymbol{\theta}_j$. Here $\mathcal{D}_i$ and $\mathcal{D}_j$ can be different tasks.
  • Figure 2: Verification of CTL. Comparison of $\mathbb{E}_{\mathcal{D}}[1-{\rm cosine}_{\alpha}^{(\ell)}(\boldsymbol{x})]$ with $\mathbb{E}_{\mathcal{D}}[1-{\rm cosine}_{i,j}^{(\ell)}(\boldsymbol{x})]$. Here, $\{\boldsymbol{\theta}_i\}_{i=1}^3$ and $\{\mathcal{D}_i\}_{i=1}^3$ denote the finetuned models and their corresponding downstream tasks. For Rotated MNIST, models are pretrained on MNIST and finetuned on variants of MNIST with digits rotated to different angles. For Split CIFAR-100, models are pretrained and finetuned on disjoint sets of 5 classes from CIFAR-100. The bottom and top of each error bar represent the lower and upper quartiles of the values across the dataset, respectively. Results are reported for the last three layers/blocks, with $\alpha \in \{0.25, 0.5, 0.75\}$. More results are in the supplementary material.
  • Figure 3: Verification of CTL. Distribution of $\text{coef}_{\alpha}^{(\ell)}(\boldsymbol{x})$ across the dataset, with $\alpha = 0.5$. $\{\boldsymbol{\theta}_i\}_{i=1}^3$ and $\{\mathcal{D}_i\}_{i=1}^3$ denote the finetuned models and their corresponding downstream tasks. We follow the same training settings as in Figure 2. Results are reported for the last three layers/blocks. More results are in the supplementary material.
  • Figure 4: Linear correlation between model-averaging accuracy and logit-ensemble accuracy. Each data point represents three models finetuned on ImageNet with different hyperparameters, denoted $\{\boldsymbol{\theta}_i\}_{i=1}^3$. The x-axis shows the accuracy of $f(\frac{1}{3}\sum_{i=1}^3\boldsymbol{\theta}_i)$, and the y-axis shows the accuracy of $\frac{1}{3}\sum_{i=1}^3f(\boldsymbol{\theta}_i)$. The grey dashed line represents $y=x$.
  • Figure 5: Verification of CTL in model averaging. Comparison of $\mathbb{E}_{\mathcal{D}}[1-{\rm cosine}_{avg}^{(\ell)}(\boldsymbol{x})]$ with $\mathbb{E}_{\mathcal{D}}[1-{\rm cosine}_{base}^{(\ell)}(\boldsymbol{x})]$. The bottom and top of each error bar represent the lower and upper quartiles of the values across the dataset, respectively. Results are reported for the last $7$ blocks of ViT-B/32 models finetuned on CIFAR-10 and ImageNet, respectively. More results are in the supplementary material.
  • ...and 19 more figures
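Figure 4 relates model averaging, $f(\frac{1}{3}\sum_{i=1}^3\boldsymbol{\theta}_i)$, to logit ensembling, $\frac{1}{3}\sum_{i=1}^3 f(\boldsymbol{\theta}_i)$. If CTL extends to several models, the two should produce nearly the same logits. The sketch below illustrates this on a hypothetical two-layer ReLU network with synthetic data, again modeling finetuning as small random perturbations of a shared checkpoint (an assumption). It does not reproduce the ImageNet or ViT-B/32 experiments, and the names `logits`, `finetuned`, and `averaged_params` are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(1)


def logits(params, x):
    """Toy two-layer ReLU network (hypothetical stand-in for f)."""
    w1, w2 = params
    return np.maximum(x @ w1, 0.0) @ w2


x = rng.normal(size=(1000, 32))  # synthetic inputs
pretrained = [rng.normal(scale=1.0 / np.sqrt(32), size=(32, 64)),
              rng.normal(scale=1.0 / np.sqrt(64), size=(64, 10))]
# Assumption: three "finetuned" models are small random perturbations of the checkpoint.
finetuned = [[w + 0.05 * rng.normal(scale=1.0 / np.sqrt(w.shape[0]), size=w.shape)
              for w in pretrained]
             for _ in range(3)]

# Model averaging: average the weights tensor by tensor, then run one forward pass.
averaged_params = [np.mean(ws, axis=0) for ws in zip(*finetuned)]
avg_logits = logits(averaged_params, x)

# Logit ensembling: run every model, then average the outputs.
ens_logits = np.mean([logits(p, x) for p in finetuned], axis=0)

rel_gap = np.linalg.norm(avg_logits - ens_logits) / np.linalg.norm(ens_logits)
agreement = np.mean(avg_logits.argmax(axis=1) == ens_logits.argmax(axis=1))
print(f"relative logit gap: {rel_gap:.2e}, prediction agreement: {agreement:.3f}")
```

When the averaged model and the ensemble agree on nearly every prediction, their accuracies must also be close, which is one way to read the $y=x$ trend in Figure 4.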

Theorems & Definitions (21)

  • Definition 3.1: Linear Mode Connectivity (nagarajan2019uniform; frankle2020linear)
  • Definition 3.2: Layerwise Linear Feature Connectivity (zhou2023going)
  • Theorem 3.3: LLFC Induces LMC (proof in the supplementary material)
  • Conjecture 4.1: Transitivity of CTL
  • Theorem 4.2: CTL Generalizes to Multiple Models (proof in the supplementary material)
  • Theorem 4.3: CTL Connects to Weight Disentanglement (proof in the supplementary material)
  • Theorem 5.1: The Emergence of CTL (proof in the supplementary material)
  • Remark 5.2
  • Definition 2.1: Transitivity of CTL
  • Lemma 2.2: CTL holds for two-model weight interpolations
  • ...and 11 more
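Theorems 4.2 and 4.3 concern CTL for more than two models and its link to weight disentanglement. One elementary observation that connects multi-model CTL to task arithmetic (our illustration, not a statement taken from the paper): with task vectors $\boldsymbol{\tau}_k = \boldsymbol{\theta}_k - \boldsymbol{\theta}_{\rm PT}$, the usual task-arithmetic edit $\boldsymbol{\theta}_{\rm PT} + \lambda(\boldsymbol{\tau}_i + \boldsymbol{\tau}_j)$ equals the affine combination $(1-2\lambda)\boldsymbol{\theta}_{\rm PT} + \lambda\boldsymbol{\theta}_i + \lambda\boldsymbol{\theta}_j$, whose coefficients sum to 1 and are all nonnegative for $0 \le \lambda \le 1/2$. A multi-model version of CTL would then predict that each layer's features of the edited model are close to the same combination of the three models' features. The sketch checks both points on a hypothetical toy network with simulated finetuning (random perturbations, an assumption); `features`, `perturb`, and `lam` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(2)


def features(params, x):
    """Hidden features and logits of a toy two-layer ReLU network (hypothetical model)."""
    w1, w2 = params
    hidden = np.maximum(x @ w1, 0.0)
    return [hidden, hidden @ w2]


def perturb(params, scale=0.05):
    """Assumption: simulate finetuning as a small random move from the checkpoint."""
    return [w + scale * rng.normal(scale=1.0 / np.sqrt(w.shape[0]), size=w.shape) for w in params]


x = rng.normal(size=(500, 32))
theta_pt = [rng.normal(scale=1.0 / np.sqrt(32), size=(32, 64)),
            rng.normal(scale=1.0 / np.sqrt(64), size=(64, 10))]
theta_i, theta_j = perturb(theta_pt), perturb(theta_pt)
lam = 0.3

# Task arithmetic: add scaled task vectors tau_k = theta_k - theta_pt to the checkpoint.
edited = [p + lam * ((a - p) + (b - p)) for p, a, b in zip(theta_pt, theta_i, theta_j)]

# The same weights written as an affine combination whose coefficients sum to 1.
coeffs = (1.0 - 2.0 * lam, lam, lam)
combo = [coeffs[0] * p + coeffs[1] * a + coeffs[2] * b
         for p, a, b in zip(theta_pt, theta_i, theta_j)]
weights_match = all(np.allclose(e, c) for e, c in zip(edited, combo))

# Multi-model CTL prediction: edited-model features vs. the combined features, per layer.
f_pt, f_i, f_j = features(theta_pt, x), features(theta_i, x), features(theta_j, x)
f_edit = features(edited, x)
rel_errors = []
for fe, fp, fa, fb in zip(f_edit, f_pt, f_i, f_j):
    predicted = coeffs[0] * fp + coeffs[1] * fa + coeffs[2] * fb
    rel_errors.append(float(np.linalg.norm(fe - predicted) / np.linalg.norm(predicted)))
print(weights_match, [f"{e:.2e}" for e in rel_errors])
```

The weight identity holds exactly for any $\lambda$; only the feature-space approximation depends on the toy models staying close to the shared checkpoint.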