
Towards Infinite-Long Prefix in Transformer

Yingyu Liang, Zhenmei Shi, Zhao Song, Chiwun Yang

TL;DR

This paper provides a convergence guarantee for training an ultra-long prefix in a stylized setting using the Neural Tangent Kernel (NTK) framework. It also designs and implements an algorithm that introduces and fine-tunes only a few extra trainable parameters in each transformer layer, instead of an infinitely long prefix.

Abstract

Prompting and context-based fine-tuning methods, which we call Prefix Learning, have been proposed to enhance the performance of language models on various downstream tasks. They are empirically efficient and effective, matching the performance of full-parameter fine-tuning, but their theoretical understanding is limited. In this paper, we aim to address this limitation by studying their capability from the perspective of prefix length. In particular, we provide a convergence guarantee for training an ultra-long prefix in a stylized setting using the Neural Tangent Kernel (NTK) framework. Based on this strong theoretical guarantee, we design and implement an algorithm that only needs to introduce and fine-tune a few extra trainable parameters instead of an infinitely long prefix in each layer of a transformer, and that can approximate the prefix attention to a guaranteed polynomially small error. Preliminary experimental results on vision, natural language, and math data show that our method achieves superior or competitive performance compared to existing methods such as full-parameter fine-tuning, P-Tuning V2, and LoRA. This demonstrates that our method is promising for parameter-efficient fine-tuning. Our code can be found at https://github.com/ChristianYang37/chiwun/tree/main/src/NTK-Attention.

Paper Structure (53 sections, 35 theorems, 149 equations, 3 figures, 2 tables, 4 algorithms)

Key Result

Lemma 3.1

Let $\delta \in (0, 0.1)$ and $B = \max \{ C\sigma \sqrt{\log(nd/\delta)}, 1\}$. Let $\widetilde{W} = [\widetilde{w}_1, \cdots, \widetilde{w}_m] \in \mathbb{R}^{d\times m}$ satisfy $\| \widetilde{w}_r - w_r(0) \|_2 \leq R$ for every $r\in [m]$, where $R$ is some constant in $(0, 0.01)$. Define $\w
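The statement above is cut off before the kernel is defined, so it is not reconstructed here. As a hedged illustration of what a kernel-convergence lemma of this shape controls, the NumPy sketch below uses the classic two-layer ReLU setting from the NTK literature rather than the paper's prefix model. The network, the initialization scale σ = 1, and all sizes are assumptions, and `ntk_gram` and `perturb_columns` are hypothetical helpers. The sketch measures how far the Gram matrix moves when every neuron $w_r$ is shifted by at most $R$ from $w_r(0)$.

```python
import numpy as np

def ntk_gram(W: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Empirical NTK Gram matrix (w.r.t. first-layer weights) of the toy network
    f(x) = m^{-1/2} * sum_r a_r * relu(w_r^T x), with a_r in {-1, +1}:
    H[i, j] = (x_i^T x_j / m) * #{r : w_r^T x_i >= 0 and w_r^T x_j >= 0}."""
    m = W.shape[1]                              # W is d x m, one column per neuron
    act = (X @ W >= 0).astype(float)            # (n, m) activation pattern
    return (X @ X.T) * (act @ act.T) / m        # a_r^2 = 1, so a drops out

def perturb_columns(W: np.ndarray, R: float, rng: np.random.Generator) -> np.ndarray:
    """Move every column w_r by a random vector of Euclidean norm exactly R."""
    D = rng.standard_normal(W.shape)
    D /= np.linalg.norm(D, axis=0, keepdims=True)
    return W + R * D

rng = np.random.default_rng(0)
n, d, m = 20, 10, 4096
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)   # unit-norm inputs
W0 = rng.standard_normal((d, m))                 # w_r(0) ~ N(0, I_d), i.e. sigma = 1 (assumption)
H0 = ntk_gram(W0, X)

# Spectral-norm change of the kernel as the per-neuron radius R grows.
changes = {}
for R in (0.0, 0.01, 0.1, 1.0):
    changes[R] = np.linalg.norm(ntk_gram(perturb_columns(W0, R, rng), X) - H0, 2)
    print(f"R = {R:<5} ||H(W~) - H(W(0))||_2 = {changes[R]:.4f}")
```

For small $R$, only neurons whose $w_r(0)$ lies close to a boundary $w^\top x_i = 0$ can flip their activation, so the kernel changes little. This is the usual mechanism behind such lemmas, not a restatement of the paper's bound.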

Figures (3)

  • Figure 1: Illustration of existing prefix attention methods (Algorithm \ref{alg:attn}) and our NTK-Attention (Algorithm \ref{alg:ntk_attn}). Compared to the former, NTK-Attention significantly reduces the number of parameters and the time complexity. Here, $X \in \mathbb{R}^{L\times d}$ is the input of this layer, $W = [W_Q, W_K, W_V]$ are the frozen attention weights, $P \in \mathbb{R}^{m \times d}$ is the trainable prefix matrix, and $Z \in \mathbb{R}^{r \times d}$ and $k \in \mathbb{R}^r$ are the trainable parameters of our method. $L$ is the input length, $d$ the input dimension, $m$ the prefix length, and $r$ a hyperparameter of NTK-Attention (the dimension of the constructed feature mapping; see Section \ref{sec:ntk_attn}). Note that $m \gg L$, and $r=d$ is used in our experiments.
  • Figure 2: Comparison of our method with LoRA and zero-shot inference on math inference datasets. The $y$-axis is accuracy.
  • Figure 3: Run time and number of parameters of one-layer NTK-Attention and prefix attention on random input data. $x$-axis: number of parameters; $y$-axis: run time. Input length $L \in \{32, 64, 128, 256\}$, dimension $d=32$, and prefix length $m \in \{2^0, 2^1, \cdots, 2^{16}\}$.
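The Figure 1 caption says NTK-Attention replaces the $m \times d$ prefix $P$ with $Z \in \mathbb{R}^{r \times d}$ and $k \in \mathbb{R}^r$. The minimal NumPy sketch below shows one way to see why such a summary can stand in for a long prefix. It assumes that (i) prefix rows pass through the frozen $W_K, W_V$, (ii) no $1/\sqrt{d}$ score scaling is used, and (iii) a positive random-feature map $\phi$ with $\mathbb{E}[\phi(q)^\top \phi(k)] = \exp(q^\top k)$ is available. That map is a Performer-style construction swapped in for illustration; the paper's own feature mapping (see the section cited in the caption) may differ. `summarize_prefix` and `summary_attention` are hypothetical names.

```python
import numpy as np

def prefix_attention(X, P, Wq, Wk, Wv):
    """Softmax attention of the L input rows over [P; X].
    Assumption: prefix rows use the same frozen W_K, W_V; no 1/sqrt(d) scaling."""
    Q = X @ Wq
    T = np.vstack([P, X])                          # (m + L, d) key/value sources
    S = Q @ (T @ Wk).T                             # (L, m + L) scores
    A = np.exp(S - S.max(axis=1, keepdims=True))   # shift for stability; cancels in the ratio
    return (A @ (T @ Wv)) / A.sum(axis=1, keepdims=True)

def random_features(Y, Omega):
    """Positive random features: E[phi(q) . phi(k)] = exp(q . k) if rows of Omega ~ N(0, I)."""
    r = Omega.shape[0]
    return np.exp(Y @ Omega.T - 0.5 * np.sum(Y**2, axis=1, keepdims=True)) / np.sqrt(r)

def summarize_prefix(P, Wk, Wv, Omega):
    """Hypothetical helper: fold an m-row prefix into Z (r x d) and k (r,)."""
    Phi_K = random_features(P @ Wk, Omega)         # (m, r)
    return Phi_K.T @ (P @ Wv), Phi_K.sum(axis=0)

def summary_attention(X, Z, k, Wq, Wk, Wv, Omega):
    """Exact attention over the L input tokens plus a feature-map estimate of the prefix terms."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = np.exp(Q @ K.T)                            # (L, L), unshifted to match the prefix terms
    Phi_Q = random_features(Q, Omega)              # (L, r)
    num = A @ V + Phi_Q @ Z                        # (L, d): input values + prefix summary
    den = A.sum(axis=1) + Phi_Q @ k                # (L,): matching normalizer
    return num / den[:, None]

rng = np.random.default_rng(0)
L, m, d, r = 4, 16, 8, 4096
X = 0.2 * rng.standard_normal((L, d))
P = 0.2 * rng.standard_normal((m, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
Omega = rng.standard_normal((r, d))

Z, k = summarize_prefix(P, Wk, Wv, Omega)
exact = prefix_attention(X, P, Wq, Wk, Wv)
approx = summary_attention(X, Z, k, Wq, Wk, Wv, Omega)
print(Z.shape, k.shape, np.abs(exact - approx).max())
```

`summary_attention` never reads $P$: once $Z$ and $k$ exist, the prefix costs $O(rd)$ per query however large $m$ is. Per the caption, $Z$ and $k$ are trained directly in the paper; computing them from a fixed $P$ here only checks that the two forms agree up to the random-feature error.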

Theorems & Definitions (87)

  • Lemma 3.1: Kernel convergence, informal version of Lemma \ref{lem:perturb_w_formal}
  • Theorem 3.2: Main result, informal version of Theorem \ref{thm:main_formal}
  • Proof sketch of Theorem \ref{thm:main_informal}
  • Proposition 3.3: Scaling Law in Prefix Learning
  • Proof sketch of Proposition \ref{pro:scaling_law}
  • Theorem 4.1: Error bound with reduced time complexity, informal version of Theorem \ref{thm:error_bound_on_ntk_attn}
  • Lemma D.3: Chernoff bound [che52]
  • Lemma D.4: Hoeffding bound [hoe44]
  • Lemma D.5: Bernstein inequality [ber24]
  • Lemma D.6: Khintchine's inequality [khi23, haa81]
  • ...and 77 more
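Theorem 4.1 concerns reduced time complexity. The counts below are a back-of-envelope sketch, not the paper's complexity statement. They read the parameter shapes off the Figure 1 caption ($P \in \mathbb{R}^{m\times d}$ versus $Z \in \mathbb{R}^{r\times d}$, $k \in \mathbb{R}^r$), use the Figure 3 settings, and assume the feature map costs $O(rd)$ per query row. All function names are hypothetical.

```python
def prefix_params(m: int, d: int) -> int:
    """Trainable prefix P in R^{m x d} (per layer)."""
    return m * d

def ntk_attention_params(r: int, d: int) -> int:
    """Trainable Z in R^{r x d} plus k in R^r (per layer)."""
    return r * d + r

def prefix_attention_cost(L: int, m: int, d: int) -> int:
    """Rough multiply-adds: L x (m + L) scores, then the weighted sum of m + L values."""
    return 2 * L * (m + L) * d

def ntk_attention_cost(L: int, r: int, d: int) -> int:
    """Rough multiply-adds: L x L input attention, plus an assumed O(r d)
    feature map per query row and the product phi(Q) Z."""
    return 2 * L * L * d + 2 * L * r * d

d = r = 32  # Figure 3 uses d = 32; the experiments use r = d
for L in (32, 64, 128, 256):
    for m in (2**0, 2**8, 2**16):
        print(f"L={L:3d} m={m:6d}  params {prefix_params(m, d):>9,} vs {ntk_attention_params(r, d):,}"
              f"  cost {prefix_attention_cost(L, m, d):>13,} vs {ntk_attention_cost(L, r, d):>11,}")
```

With $d = r = 32$, the NTK-Attention side stays at 1,056 trainable parameters per layer for every $m$, while the prefix side reaches 2,097,152 at $m = 2^{16}$; only the prefix cost grows with $m$.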