Understanding Linear Probing then Fine-tuning Language Models from NTK Perspective

Akiyoshi Tomihari, Issei Sato

TL;DR

This paper analyzes the training dynamics of LP-FT for classification tasks on the basis of neural tangent kernel (NTK) theory and decomposes the NTK matrix into two components, highlighting the importance of the linear head norm alongside the prediction accuracy at the start of the FT stage.

Abstract

The two-stage fine-tuning (FT) method, linear probing (LP) then fine-tuning (LP-FT), outperforms both LP and FT alone on in-distribution (ID) as well as out-of-distribution (OOD) data. One key reason for its success is the preservation of pre-trained features, achieved by obtaining a near-optimal linear head during LP. However, despite the widespread use of large language models, LP-FT has seen limited analysis for more complex architectures such as Transformers. In this paper, we analyze the training dynamics of LP-FT for classification tasks on the basis of neural tangent kernel (NTK) theory. Our analysis decomposes the NTK matrix into two components, which highlights the importance of the linear head norm alongside the prediction accuracy at the start of the FT stage. We also observe a significant increase in the linear head norm during LP, which stems from training with the cross-entropy (CE) loss. This increased norm effectively reduces changes in the learned features. Furthermore, we find that it can adversely affect model calibration, which can be corrected using temperature scaling. Additionally, we extend our NTK analysis to the low-rank adaptation (LoRA) method and validate its effectiveness. Our experiments using a Transformer-based model on multiple natural language processing datasets confirm our theoretical analysis. Our study demonstrates the effectiveness of LP-FT for fine-tuning language models. Code is available at https://github.com/tom4649/lp-ft_ntk.
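The chain of claims in the abstract (CE loss inflates the linear head norm during LP, the inflated norm hurts calibration, and temperature scaling repairs it) can be illustrated with a toy sketch. It assumes a binary task, frozen 2-D features standing in for the pre-trained features, a bias-free head trained by full-batch gradient descent on sigmoid CE, and a grid search for the temperature. The data, the function names (`train_linear_probe`, `nll`, `fit_temperature`), and the hyperparameters are hypothetical and are not taken from the paper's code.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_linear_probe(feats, labels, lr=0.5, epochs=2000, log_every=500):
    """LP sketch: gradient descent on binary CE for a linear head w over frozen features."""
    w = [0.0] * len(feats[0])
    norms = []  # head norm recorded every `log_every` epochs
    for epoch in range(1, epochs + 1):
        grad = [0.0] * len(w)
        for x, y in zip(feats, labels):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
            for j, xj in enumerate(x):
                grad[j] += (p - y) * xj / len(feats)
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
        if epoch % log_every == 0:
            norms.append(math.sqrt(sum(wi * wi for wi in w)))
    return w, norms


def nll(w, feats, labels, temperature=1.0):
    """Mean binary CE of the head's predictions with logits divided by `temperature`."""
    total = 0.0
    for x, y in zip(feats, labels):
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) / temperature)
        p = min(max(p, 1e-12), 1 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(feats)


def fit_temperature(w, feats, labels, grid=None):
    # Temperature scaling: pick the T minimizing held-out NLL (grid T = 0.5, 0.75, ..., 20).
    grid = grid or [0.5 + 0.25 * k for k in range(79)]
    return min(grid, key=lambda t: nll(w, feats, labels, t))


# Hypothetical toy data: linearly separable training features.
train_x = [(1.0, 0.5), (2.0, 1.0), (1.5, -0.5), (-1.0, 0.5), (-2.0, -1.0), (-1.5, 0.2)]
train_y = [1, 1, 1, 0, 0, 0]
# Held-out set with two points the probe gets wrong, so overconfidence is penalized.
val_x = [(1.2, 0.3), (-1.2, -0.3), (0.5, 0.0), (-0.5, 0.0)]
val_y = [1, 0, 0, 1]

if __name__ == "__main__":
    w, norms = train_linear_probe(train_x, train_y)
    print("head norm every 500 epochs:", [round(n, 2) for n in norms])
    t = fit_temperature(w, val_x, val_y)
    print("fitted temperature:", t)
    print("held-out NLL at T=1 vs fitted T:", nll(w, val_x, val_y), nll(w, val_x, val_y, t))
```

On separable training features the CE loss keeps falling as the head is scaled up, so the logged norm keeps growing; the two held-out points the probe misclassifies are then predicted with high confidence, and a fitted temperature above 1 lowers the held-out NLL without changing any prediction.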

Paper Structure

This paper contains 60 sections, 4 theorems, 33 equations, 42 figures, and 16 tables.

Key Result

Proposition 1

The NTK of a model $\bm{f}(\bm{x})=\bm{V}\bm{\phi}(\bm{x})+\bm{b}$, denoted by $\Theta^{\bm{f}}$, can be decomposed as $\Theta^{\bm{f}}(\bm{x}, \bm{x}_i) = \bm{P}(\bm{x}, \bm{x}_i) + \bm{F}(\bm{x}, \bm{x}_i)$, where the pre-train-effective component $\bm{P}(\bm{x}, \bm{x}_i)$ and the FT-effective component $\bm{F}(\bm{x}, \bm{x}_i)$ are defined using the classifier weight matrix $\bm{V}_0$ and the feature extractor $\bm{\phi}_0$ at the starting point of training. Consequently, assumin…
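Where the two components come from can be seen by differentiating the model directly. The following is a derivation sketch, not the paper's exact statement. It assumes the feature extractor is $\bm{\phi}(\bm{x};\bm{\theta})$ with parameters $\bm{\theta}$, that there are $K$ output classes, that $\bm{V}$, $\bm{b}$, and $\bm{\theta}$ are all trained, and that the NTK is the Gram matrix of the output Jacobians at the start of training. Grouping the bias term into $\bm{P}$ is our own bookkeeping choice. The kernel splits over the three parameter groups:

$$
\Theta^{\bm{f}}(\bm{x},\bm{x}_i)
= \nabla_{\bm{V}}\bm{f}(\bm{x})\,\nabla_{\bm{V}}\bm{f}(\bm{x}_i)^{\top}
+ \nabla_{\bm{b}}\bm{f}(\bm{x})\,\nabla_{\bm{b}}\bm{f}(\bm{x}_i)^{\top}
+ \nabla_{\bm{\theta}}\bm{f}(\bm{x})\,\nabla_{\bm{\theta}}\bm{f}(\bm{x}_i)^{\top}.
$$

Because $\bm{f}$ is linear in $\bm{V}$ and $\bm{b}$, and $\nabla_{\bm{\theta}}\bm{f}(\bm{x}) = \bm{V}_0\nabla_{\bm{\theta}}\bm{\phi}_0(\bm{x})$ by the chain rule, the three terms evaluate to

$$
\Theta^{\bm{f}}(\bm{x},\bm{x}_i)
= \underbrace{\big(\langle\bm{\phi}_0(\bm{x}),\bm{\phi}_0(\bm{x}_i)\rangle + 1\big)\bm{I}_K}_{\bm{P}(\bm{x},\bm{x}_i)}
+ \underbrace{\bm{V}_0\,\nabla_{\bm{\theta}}\bm{\phi}_0(\bm{x})\,\nabla_{\bm{\theta}}\bm{\phi}_0(\bm{x}_i)^{\top}\bm{V}_0^{\top}}_{\bm{F}(\bm{x},\bm{x}_i)}.
$$

$\bm{P}$ depends only on the pre-trained features, while $\bm{F}$ is quadratic in $\bm{V}_0$. This is one way to see why the norm of the linear head obtained by LP enters the FT dynamics.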

Figures (42)

  • Figure 1: LP
  • Figure 2: FT
  • Figure 3: After training
  • Figure 5: Singular value distribution normalized by the maximum value on the CB dataset, showing the common pre-train-effective component (Pre-train E) and the FT-effective components for each training option.
  • Figure 6: Difference in features on SST-5 (OOD). The dashed vertical lines indicate the original classifier weight norm after training.
  • ...and 37 more figures

Theorems & Definitions (8)

  • Proposition 1: FT in the NTK regime
  • Definition 1: Linear model (Kumar et al., 2022)
  • Corollary 1: Lemma A.3 from Kumar et al. in the NTK regime
  • Proposition 2: LoRA approximates FT
  • Proof
  • Proof
  • Lemma 1: Corollary of the distributional Johnson-Lindenstrauss Lemma
  • Proof
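Proposition 2 and Lemma 1 in the list above tie LoRA to FT through a random projection. The following minimal numerical sketch illustrates that idea under assumptions that are ours, not the paper's. We take a single linear layer with output $\bm{h}=(\bm{W}+\bm{B}\bm{A})\bm{x}$, LoRA initialized with $\bm{B}=\bm{0}$ and $\bm{A}$ having i.i.d. $\mathcal{N}(0, 1/r)$ entries, and no $\alpha/r$ scaling. With upstream gradient $\bm{g}$, FT's NTK contribution from $\bm{W}$ is $\langle\bm{g},\bm{g}'\rangle\langle\bm{x},\bm{x}'\rangle$. LoRA's contribution is $\langle\bm{g},\bm{g}'\rangle\langle\bm{A}\bm{x},\bm{A}\bm{x}'\rangle$, because only $\bm{B}$ has a nonzero gradient when $\bm{B}=\bm{0}$. Their ratio, $\langle\bm{A}\bm{x},\bm{A}\bm{x}'\rangle/\langle\bm{x},\bm{x}'\rangle$, concentrates near 1 as the rank $r$ grows, by the distributional Johnson-Lindenstrauss lemma. The helper names `lora_ntk_ratio` and `mean_abs_error` are hypothetical.

```python
import math
import random


def lora_ntk_ratio(x, x_prime, rank, rng):
    """Ratio <Ax, Ax'> / <x, x'> for one random LoRA down-projection A (rank x dim).

    Assumption: A has i.i.d. N(0, 1/rank) entries and B = 0, so only the gradient
    w.r.t. B is nonzero and the shared <g, g'> factor cancels in the ratio.
    """
    std = 1.0 / math.sqrt(rank)
    ax, ax_prime = [], []
    for _ in range(rank):
        row = [rng.gauss(0.0, std) for _ in range(len(x))]  # one row of A
        ax.append(sum(a * xi for a, xi in zip(row, x)))
        ax_prime.append(sum(a * xi for a, xi in zip(row, x_prime)))
    dot_proj = sum(u * v for u, v in zip(ax, ax_prime))
    dot_orig = sum(u * v for u, v in zip(x, x_prime))
    return dot_proj / dot_orig


def mean_abs_error(rank, dim=64, trials=30, seed=0):
    """Average |LoRA-NTK / FT-NTK - 1| over random input pairs and random A."""
    rng = random.Random(seed)
    errors = []
    for _ in range(trials):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # x' is correlated with x so that <x, x'> stays well away from zero.
        x_prime = [xi + 0.5 * rng.gauss(0.0, 1.0) for xi in x]
        errors.append(abs(lora_ntk_ratio(x, x_prime, rank, rng) - 1.0))
    return sum(errors) / trials


if __name__ == "__main__":
    for r in (4, 16, 64, 256):
        print(f"rank {r:>3}: mean |ratio - 1| = {mean_abs_error(r):.3f}")
```

The mean deviation from 1 shrinks roughly like $1/\sqrt{r}$, so a rank-256 adapter tracks the FT kernel of this layer far more closely than a rank-4 one. Real LoRA implementations use different initial variances and an $\alpha/r$ factor, which rescale the kernel by a constant but do not change this concentration argument.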