Fine-Tuning Without Forgetting In-Context Learning: A Theoretical Analysis of Linear Attention Models

Chungpa Lee, Jy-yong Sohn, Kangwook Lee

TL;DR

Fine-tuning all attention parameters can harm in-context learning, whereas restricting updates to the value matrix improves zero-shot performance while preserving in-context learning. Adding an auxiliary few-shot loss enhances in-context learning primarily on the target task, at the cost of degraded in-context learning on tasks not seen during fine-tuning.

Abstract

Transformer-based large language models exhibit in-context learning, enabling adaptation to downstream tasks via few-shot prompting with demonstrations. In practice, such models are often fine-tuned to improve zero-shot performance on downstream tasks, allowing them to solve tasks without examples and thereby reducing inference costs. However, fine-tuning can degrade in-context learning, limiting the performance of fine-tuned models on tasks not seen during fine-tuning. Using linear attention models, we provide a theoretical analysis that characterizes how fine-tuning objectives modify attention parameters and identifies conditions under which this leads to degraded few-shot performance. We show that fine-tuning all attention parameters can harm in-context learning, whereas restricting updates to the value matrix improves zero-shot performance while preserving in-context learning. We further show that incorporating an auxiliary few-shot loss enhances in-context learning primarily on the target task, at the expense of degraded in-context learning ability on tasks not seen during fine-tuning. We empirically validate our theoretical results.
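
The zero-shot and few-shot settings discussed in the abstract can be illustrated with a minimal sketch. It assumes a single-layer linear attention parameterization that is common in the in-context regression literature: the prompt is a matrix whose columns are the demonstrations $(x_i, y_i)$ followed by the query $(x_{\text{query}}, 0)$, and the prediction is read from the bottom-right entry of the output. The paper's exact model, scaling, and masking are not given in this excerpt, so the function name `linear_attention_predict`, the `1/n` scaling, and the example parameters are all illustrative assumptions.

```python
import numpy as np

def linear_attention_predict(V, Q, X, y, x_query):
    """Predict the label of x_query from n demonstrations (X, y); n may be 0.

    Assumed form (not taken from the paper): out = Z + V Z (Z^T Q Z) / n,
    with the prediction read from the bottom-right entry of `out`.
    """
    d = x_query.shape[0]
    n = X.shape[0]
    # Prompt matrix Z of shape (d+1, n+1): column i is (x_i, y_i),
    # the last column is (x_query, 0) with the unknown label set to 0.
    Z = np.zeros((d + 1, n + 1))
    Z[:d, :n] = X.T
    Z[d, :n] = y
    Z[:d, n] = x_query
    scale = max(n, 1)  # avoid dividing by zero in the zero-shot case
    out = Z + V @ Z @ (Z.T @ Q @ Z) / scale
    return float(out[d, n])

# Hypothetical parameters: V copies the label row, Q compares inputs only.
d = 2
V = np.zeros((d + 1, d + 1))
V[d, d] = 1.0
Q = np.zeros((d + 1, d + 1))
Q[:d, :d] = np.eye(d)

theta = np.array([1.0, 2.0])   # task vector, y = theta^T x
X = np.eye(d)                  # two demonstrations
y = X @ theta
x_query = np.array([1.0, 1.0])

few_shot = linear_attention_predict(V, Q, X, y, x_query)
zero_shot = linear_attention_predict(V, Q, np.zeros((0, d)), np.zeros(0), x_query)
print(few_shot)   # 1.5
print(zero_shot)  # 0.0
```

With these hypothetical parameters the few-shot prediction equals $\tfrac{1}{n}\sum_i y_i\, x_i^\top x_{\text{query}}$, and the zero-shot prediction is $0$ because no task information is stored in the weights. In this picture, fine-tuning for zero-shot performance amounts to changing `V` and `Q` so that the second call already returns a useful value.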

Paper Structure

This paper contains 35 sections, 29 theorems, 151 equations, 8 figures, 2 tables.

Key Result

Corollary 4.1

Suppose that ${\mathbf{V}}$ and ${\mathbf{Q}}$ in the linear attention model (eq:linear:attention:model) satisfy $q \ge 0$ and that ${\mathbf{Q}}_{11}$ is positive definite. Consider the loss $\mathcal{L}({\mathbf{V}},{\mathbf{Q}})$ in eq:pretrain:loss with context length $m\in{\mathbb{N}}$. Then, for sufficiently large $m$, [...] All unspecified blocks (denoted by $\cdot$) are zero. Moreover, ${\mathbf{U}} \mathop{\mathrm{diag}}$ [...]

Figures (8)

  • Figure 1: $n$-shot test error on the target task ${\mathbf{\theta}}_0$ of the pretrained model using the theory-derived parameters in Corollary \ref{thm:optimal:pretrain}. Each curve corresponds to a model pretrained with context length $m=1000$ and a different input dimension $d$. We set $\sigma^2=0$, under which the condition in \ref{eq:condition:large:n:pretrain:reduced} predicts that the $n$-shot performance is worse than zero-shot performance for $n$ ranging from $1$ to $d-2$.
  • Figure 2: $n$-shot test error on the target task ${\mathbf{\theta}}_0$ of the fully fine-tuned model using the theory-derived parameters in Theorem \ref{thm:optimal:full-finetune}. Each curve corresponds to a model with a different choice of the parameter $w$, evaluated under $\sigma^2=0.1$ and ${\mathbf{\theta}}_0^\top{\mathbf{\Sigma}}{\mathbf{\theta}}_0=1$. For all values of $w$, the zero-shot error is $0.1$, whereas the few-shot error is larger than the zero-shot error and converges to $1$ as $n \to \infty$.
  • Figure 3: (Top) $n$-shot test error on the target task ${\mathbf{\theta}}_0$ for the value-matrix fine-tuned model with the theory-derived parameters $(\hat{{\mathbf{V}}}(w), \hat{{\mathbf{Q}}})$ in Theorem \ref{thm:optimal:value-matrix}. Each curve corresponds to a different choice of $w$. We use $d=5$, ${\mathbf{\theta}}_0^\top{\mathbf{\Sigma}}{\mathbf{\theta}}_0 = 1$, and $\sigma^2 = 0.1$, for which the zero-shot test error equals $\sigma^2 +\tfrac{2}{d+4}{\mathbf{\theta}}_0^\top{\mathbf{\Sigma}}{\mathbf{\theta}}_0 = 0.1 + \tfrac{2}{9} \approx 0.32$ for all models. (Bottom) Optimal $w^\star(n;{\mathbf{\theta}}_0)$ that minimizes the $n$-shot test error on the target task ${\mathbf{\theta}}_0$. The optimal value increases monotonically with $n$ and converges to $8/9 \,(= \tfrac{d+3}{d+4})$. For example, the $20$-shot test error is minimized at $w \approx 0.66$.
  • Figure 4: $n$-shot test errors on the target task ${\mathbf{\theta}}_0$ for linear attention models under different training regimes. Curves show theoretical predictions from Section \ref{sec:theory:optimal}, while points correspond to models trained empirically in Section \ref{sec:experiment:regression}. Empirical results match the theoretical values. The parameters used for the theoretical predictions are given by: (a) \ref{eq:parameter_pretrain} from Corollary \ref{thm:optimal:pretrain}. (b) \ref{eq:finetune:zs:parameter} from Theorem \ref{thm:optimal:full-finetune} with $w=0.52$. (c) \ref{eq:parameter:value:only} from Theorem \ref{thm:optimal:value-matrix} with $w=0.77\;(=\tfrac{m}{m+1+d})$ (Proposition \ref{thm:optimal:w:zs}). (d) \ref{eq:parameter:value:only} with $w=0.66$ (Theorem \ref{thm:optimal:w}; see Figure \ref{fig:optimal_zs_finetune_v_only}).
  • Figure 5: $n$-shot test errors evaluated on the in-distribution task ${\mathbf{\theta}}_0$ (Left) and the out-of-distribution task $-{\mathbf{\theta}}_0$ (Right). Empty circles denote models fine-tuned with the zero-shot (ZS) loss only, whereas filled circles denote models fine-tuned with both the zero-shot and few-shot (FS) losses. Parentheses indicate which parameters are updated: "all" updates all parameters, "Q" updates the query-key matrix, and "V" updates the value matrix. Value-matrix fine-tuning without the auxiliary few-shot loss gives the lowest error on the out-of-distribution task $-{\mathbf{\theta}}_0$.
  • ...and 3 more figures
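
The closed-form quantities quoted in the Figure 3 and Figure 4 captions can be checked with a few lines of arithmetic. The sketch below only evaluates the expressions as they appear in the captions; the function names are illustrative. The caption of Figure 4 does not state $m$ or $d$, so the values $m=20$ and $d=5$ used for the last check are hypothetical, chosen only because they reproduce $w \approx 0.77$.

```python
def zero_shot_error(sigma2, d, signal):
    """sigma^2 + 2/(d+4) * theta_0^T Sigma theta_0, as written in the Figure 3 caption."""
    return sigma2 + 2.0 / (d + 4) * signal

def limiting_optimal_w(d):
    """Limit of the task-wise optimal w as n grows, (d+3)/(d+4), per the Figure 3 caption."""
    return (d + 3) / (d + 4)

def zero_shot_optimal_w(m, d):
    """w = m / (m + 1 + d), the expression quoted in the Figure 4(c) caption."""
    return m / (m + 1 + d)

# Figure 3 settings: d = 5, theta_0^T Sigma theta_0 = 1, sigma^2 = 0.1
print(round(zero_shot_error(0.1, 5, 1.0), 3))   # 0.322
print(limiting_optimal_w(5) == 8 / 9)           # True

# Hypothetical m = 20, d = 5 (not stated in the caption)
print(round(zero_shot_optimal_w(20, 5), 2))     # 0.77
```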

Theorems & Definitions (46)

  • Definition 3.1: Test Error
  • Corollary 4.1: Optimal Parameters of Pretrained Models (corollary of Theorem 1 in ahn2023transformers)
  • Corollary 4.2: Test Error of Pretrained Models
  • Theorem 4.3: Optimal Parameters of Fully Fine-Tuned Models
  • Corollary 4.4: Test Error of Fully Fine-Tuned Models
  • Theorem 4.5
  • Theorem 4.6: Optimal Parameters and Test Error of Value-Matrix Fine-Tuned Models
  • Proposition 4.7
  • Corollary 4.8: Task-Averaged Optimal $w$
  • Theorem 4.9: Task-Wise Optimal $w$
  • ...and 36 more