
A Kernel-Based View of Language Model Fine-Tuning

Sadhika Malladi, Alexander Wettig, Dingli Yu, Danqi Chen, Sanjeev Arora

TL;DR

This work argues that fine-tuning pre-trained language models can be understood through a kernel lens, by extending the Neural Tangent Kernel to pre-trained initializations and to Adam. It derives the Asymmetric SignGD (A-SignGD) and SignGD kernels, formalizes when prompting induces kernel behavior, and demonstrates across 14 NLP tasks that prompt-based fine-tuning often follows kernel dynamics, with the empirical NTK (eNTK) matching fine-tuning performance in many cases. The paper further shows that subspace-based fine-tuning methods such as LoRA preserve the kernel when fine-tuning exhibits kernel behavior, offering a principled explanation for their efficacy. Overall, the kernel view provides theoretical and empirical grounding for the sample-efficient success of prompting and subspace-based fine-tuning, while outlining its limitations and avenues for extending the framework beyond the early stage of training.
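
The kernel-preservation claim for subspace methods can be illustrated, very loosely, with a random-projection toy. Assuming the trainable subspace acts like a random low-dimensional projection of the gradient, inner products between projected gradients approximately preserve the eNTK entries (a Johnson-Lindenstrauss-style effect). This is a minimal sketch under that assumption, not LoRA itself and not the paper's argument; the gradient vectors are random stand-ins, and `inner` and `project` are hypothetical helper names.

```python
import math
import random

def inner(u, v):
    """Euclidean inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def project(P, v):
    """Apply a k x d projection matrix P (list of rows) to a d-dimensional vector v."""
    return [inner(row, v) for row in P]

rng = random.Random(0)
d, k = 50, 2000  # toy full dimension and toy subspace dimension (assumptions)

# Random stand-ins for the gradients grad f(xi_1; theta_0) and grad f(xi_2; theta_0).
g1 = [rng.gauss(0.0, 1.0) for _ in range(d)]
g2 = [0.5 * a + rng.gauss(0.0, 1.0) for a in g1]  # correlated with g1

# Gaussian projection scaled by 1/sqrt(k), so that E[<P u, P v>] = <u, v>.
P = [[rng.gauss(0.0, 1.0) / math.sqrt(k) for _ in range(d)] for _ in range(k)]

full_kernel = inner(g1, g2)                         # eNTK entry in the full space
sub_kernel = inner(project(P, g1), project(P, g2))  # same entry in the subspace
print(full_kernel, sub_kernel)
```

The approximation error shrinks roughly like 1/sqrt(k), which is why a large enough random subspace leaves the kernel nearly unchanged in this toy.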

Abstract

It has become standard to solve NLP tasks by fine-tuning pre-trained language models (LMs), especially in low-data settings. There is minimal theoretical understanding of this empirical success, e.g., why fine-tuning a model with $10^8$ or more parameters on a couple dozen training points does not result in overfitting. We investigate whether the Neural Tangent Kernel (NTK), which originated as a model to study the gradient descent dynamics of infinitely wide networks with suitable random initialization, describes fine-tuning of pre-trained LMs. This study was inspired by the decent performance of NTK for computer vision tasks (Wei et al., 2022). We extend the NTK formalism to Adam and use Tensor Programs (Yang, 2020) to characterize conditions under which the NTK lens may describe fine-tuning updates to pre-trained language models. Extensive experiments on 14 NLP tasks validate our theory and show that formulating the downstream task as a masked word prediction problem through prompting often induces kernel-based dynamics during fine-tuning. Finally, we use this kernel view to propose an explanation for the success of parameter-efficient subspace-based fine-tuning methods.
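
The kernels for Adam in this work are stated in terms of SignGD (Definition 4.1). One exact link between the two: with zero-initialized moments, β1 = β2 = 0 and ε = 0, Adam's step m̂/√v̂ becomes g/|g| = sign(g) coordinate-wise. The sketch below is a hand-written single Adam step (not a library optimizer); `adam_step` and `signgd_step` are hypothetical helpers for illustration.

```python
import math

def sign(x):
    """Coordinate-wise sign: -1.0, 0.0 or 1.0."""
    return float((x > 0) - (x < 0))

def adam_step(theta, grad, m, v, t, lr, beta1, beta2, eps):
    """One hand-written Adam update with bias correction (t counts from 1)."""
    new_theta, new_m, new_v = [], [], []
    for th, g, mi, vi in zip(theta, grad, m, v):
        mi = beta1 * mi + (1.0 - beta1) * g        # first-moment estimate
        vi = beta2 * vi + (1.0 - beta2) * g * g    # second-moment estimate
        m_hat = mi / (1.0 - beta1 ** t)            # bias-corrected moments
        v_hat = vi / (1.0 - beta2 ** t)
        new_theta.append(th - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_theta, new_m, new_v

def signgd_step(theta, grad, lr):
    """SignGD update: theta <- theta - lr * sign(grad)."""
    return [th - lr * sign(g) for th, g in zip(theta, grad)]

theta = [0.3, -1.2, 2.0]
grad = [0.05, -4.0, 1e-3]  # nonzero entries: with eps = 0 a zero gradient would give 0 / 0
zeros = [0.0] * len(theta)

# beta1 = beta2 = 0 and eps = 0: the Adam step reduces to lr * sign(g).
adam_theta, _, _ = adam_step(theta, grad, zeros, zeros, t=1, lr=0.1,
                             beta1=0.0, beta2=0.0, eps=0.0)
sign_theta = signgd_step(theta, grad, lr=0.1)
print(adam_theta, sign_theta)
```

With zero-initialized moments the same identity also holds at t = 1 for any β1 and β2, because bias correction gives m̂ = g and v̂ = g²; at later steps it holds only when β1 = β2 = 0.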

Paper Structure (56 sections, 10 theorems, 56 equations, 2 figures, 12 tables)

Key Result

Theorem 4.3

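If a network is trained with SignGD and exhibits kernel behavior (Definition 3.2), then the training dynamics follow

$$\frac{d f(\xi; \theta_t)}{dt} \approx -\,\mathcal{K}^{\text{(A-SignGD)}}(\xi, \xi_t)\, \operatorname{sign}(\chi_t),$$

where $\xi_t$ is the training input at step $t$ and $\chi_t$ is the output derivative (Definition 3.1).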
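
The mechanism behind this result can be checked in the simplest setting. For one SignGD step on a single training example ξ_j with scalar output derivative χ_j, the loss gradient is χ_j ∇f(ξ_j; θ), so its sign is sign(χ_j) sign(∇f(ξ_j; θ)). For a model that is linear in its parameters, the change of f at any input ξ is then exactly −η sign(χ_j) ⟨∇f(ξ), sign(∇f(ξ_j))⟩, an Asymmetric SignGD kernel entry. This is a minimal sketch assuming a toy linear model f(ξ; θ) = ⟨θ, φ(ξ)⟩ with hand-picked features and squared loss; `features`, `f`, and the data are hypothetical.

```python
def sign(x):
    """Coordinate-wise sign: -1.0, 0.0 or 1.0."""
    return float((x > 0) - (x < 0))

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical feature vectors; for a linear model, grad_theta f(xi) = phi(xi).
features = {
    "xi_test": [0.4, -1.0, 2.5, 0.0],
    "xi_train": [1.5, 0.2, -0.7, 3.0],
}

def f(theta, name):
    """Toy linear-in-parameters model f(xi; theta) = <theta, phi(xi)>."""
    return inner(theta, features[name])

theta0 = [0.1, -0.2, 0.3, 0.05]
y_train, lr = 1.0, 1e-2

# Output derivative chi = d loss / d f for the squared loss 0.5 * (f - y)^2.
chi = f(theta0, "xi_train") - y_train

# One SignGD step on the single training example.
grad_loss = [chi * g for g in features["xi_train"]]
theta1 = [t - lr * sign(g) for t, g in zip(theta0, grad_loss)]

# Asymmetric SignGD kernel entry <grad f(xi_test), sign(grad f(xi_train))>.
k_asym = inner(features["xi_test"], [sign(g) for g in features["xi_train"]])

actual_change = f(theta1, "xi_test") - f(theta0, "xi_test")
predicted_change = -lr * sign(chi) * k_asym
print(actual_change, predicted_change)
```

For a real network the agreement is only approximate and relies on the linearization part of kernel behavior.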

Figures (2)

  • Figure 1: The performance gap between SGD-FT and the ${\mathcal{K}}^{\text{(SGD)}}$ kernel in both the standard and the prompt-based settings (see the preliminaries) suggests that using a prompt is important for kernel behavior (Definition 3.2) to arise. In standard FT, we initialize the new classification head (i.e., $\Gamma$) using the linear probing solution. Performance is shown for the $64$-shot setting and measured by the average test accuracy over 5 random splits, except for MRPC and QQP, where it is F1. Results on additional settings are reported in a separate table.
  • Figure 2: Accuracies of the zero-shot pre-trained model (PT), the linearized model (Lin., see Definition 3.2), and the fine-tuned model (FT). For tasks that induce the Linearization property of kernel behavior (Definition 3.2), Lin. recovers a substantial portion of the performance of SGD-FT and Adam-FT, respectively. We plot the median and range of the test accuracies across 5 seeds and data splits for $k=64$.
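
The linearized model (Lin.) can be sketched as the first-order Taylor expansion around the pre-trained parameters, f_lin(ξ; θ) = f(ξ; θ_0) + ⟨∇f(ξ; θ_0), θ − θ_0⟩, which we assume matches the Linearization property of Definition 3.2. The toy below uses a hypothetical two-parameter nonlinear model, not a language model, to show that the linearization is accurate for small parameter changes and degrades as the change grows.

```python
import math

def f(theta, x):
    """Toy nonlinear model f(x; a, b) = b * tanh(a * x)."""
    a, b = theta
    return b * math.tanh(a * x)

def grad_f(theta, x):
    """Analytic gradient of f with respect to (a, b)."""
    a, b = theta
    t = math.tanh(a * x)
    return [b * x * (1.0 - t * t), t]

def f_lin(theta, theta0, x):
    """First-order Taylor expansion of f around theta0."""
    g = grad_f(theta0, x)
    return f(theta0, x) + sum(gi * (ti - t0) for gi, ti, t0 in zip(g, theta, theta0))

theta0 = [0.8, 1.5]  # stand-in for the pre-trained parameters
x = 0.7
errors = []
for step in (1e-3, 1e-1, 1.0):
    theta = [theta0[0] + step, theta0[1] - step]
    err = abs(f(theta, x) - f_lin(theta, theta0, x))
    errors.append(err)
    print(f"step={step:g}  |f - f_lin| = {err:.2e}")
```

The error grows roughly quadratically in the step size, so Lin. tracks fine-tuning only while the parameters stay close to θ_0.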

Theorems & Definitions (34)

  • Definition 3.1: Output Derivative
  • Definition 3.2: Kernel Behavior
  • Definition 3.3: Kernel Analog
  • Definition 3.4: Neural Tangent Kernel ${\mathcal{K}}^{\text{(SGD)}}$
  • Definition 4.1: SignGD
  • Definition 4.2: Asymmetric SignGD Kernel
  • Theorem 4.3: Informal version of the formal SignGD kernel theorem
  • Proof sketch
  • Definition 4.4: SignGD Kernel
  • Definition 5.1: Pre-Training Scheme
  • ...and 24 more
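
The kernels in Definitions 3.4, 4.2, and 4.4 are built from per-example gradients ∇f(ξ; θ_0) at the pre-trained parameters. The sketch below assumes the natural forms K^(SGD)(i, j) = ⟨∇f(ξ_i), ∇f(ξ_j)⟩, K^(A-SignGD)(i, j) = ⟨∇f(ξ_i), sign(∇f(ξ_j))⟩ and K^(SignGD)(i, j) = ⟨sign(∇f(ξ_i)), sign(∇f(ξ_j))⟩, and pairs each with kernel ridge regression as one possible kernel-analog solver. The ridge solver, the hand-picked gradient vectors, and all function names are hypothetical; the paper's own solver may differ.

```python
def sign(x):
    return float((x > 0) - (x < 0))

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

def kernel_matrix(grads_a, grads_b, kind):
    """Gram matrix between two lists of gradient vectors for the chosen kernel."""
    def entry(g, h):
        if kind == "sgd":       # eNTK: <grad f(xi_i), grad f(xi_j)>
            return inner(g, h)
        if kind == "a-signgd":  # <grad f(xi_i), sign(grad f(xi_j))>
            return inner(g, [sign(x) for x in h])
        if kind == "signgd":    # <sign(grad f(xi_i)), sign(grad f(xi_j))>
            return inner([sign(x) for x in g], [sign(x) for x in h])
        raise ValueError(f"unknown kernel kind: {kind}")
    return [[entry(g, h) for h in grads_b] for g in grads_a]

def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [bi] for row, bi in zip(A, b)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= factor * M[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (M[r][n] - sum(M[r][c] * x[c] for c in range(r + 1, n))) / M[r][r]
    return x

def kernel_ridge_predict(train_grads, y, test_grads, kind, ridge=1e-3):
    """Solve (K + ridge * I) alpha = y, then predict K(test, train) @ alpha."""
    K = kernel_matrix(train_grads, train_grads, kind)
    for i in range(len(K)):
        K[i][i] += ridge
    alpha = solve(K, y)
    return [inner(row, alpha) for row in kernel_matrix(test_grads, train_grads, kind)]

# Hypothetical per-example gradients and +/-1 labels.
train_grads = [[1.0, -0.5, 0.2], [0.3, 0.8, -1.1], [-0.6, 0.1, 0.9]]
y = [1.0, -1.0, 1.0]
test_grads = [[0.9, -0.4, 0.3]]
for kind in ("sgd", "a-signgd", "signgd"):
    print(kind, kernel_ridge_predict(train_grads, y, test_grads, kind))
```

The A-SignGD Gram matrix is not symmetric, so the sketch uses a general linear solver rather than a Cholesky factorization.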