Context-aware Prompt Tuning: Advancing In-Context Learning with Adversarial Methods

Tsachi Blau, Moshe Kimhi, Yonatan Belinkov, Alexander Bronstein, Chaim Baskin

TL;DR

This paper introduces Context-aware Prompt Tuning (CPT), a method inspired by in-context learning (ICL), prompt tuning (PT), and adversarial attacks. CPT builds on the ICL strategy of concatenating examples before the input and extends it with PT-style learning: the context embedding is refined through iterative optimization to extract more information from the training examples.

Abstract

Fine-tuning Large Language Models (LLMs) typically involves updating at least a few billion parameters. Prompt Tuning (PT) is a more parameter-efficient approach that updates only a few learnable tokens. In-Context Learning (ICL), by contrast, adapts the model to a new task simply by including examples in the input, without any training. Optimization-based methods such as fine-tuning and PT adapt the model specifically to the small set of few-shot training examples, whereas ICL leaves the model unchanged. This makes optimization-based methods more prone to overfitting, while ICL is less sensitive to the few-shot setting. However, although ICL is less prone to overfitting, it does not fully exploit the information in the training examples. This work introduces Context-aware Prompt Tuning (CPT), a method inspired by ICL, PT, and adversarial attacks. We build on the ICL strategy of concatenating examples before the input and extend it with PT-like learning, refining the context embedding through iterative optimization to extract more information from the training examples. We carefully modify specific context tokens, taking into account the distinct structure of the input and output formats. Inspired by adversarial attacks, we adjust the input based on the labels present in the context, but we minimize the loss rather than maximize it. In addition, we apply projected gradient descent to keep token embeddings close to their original values, under the assumption that the user-provided data is inherently valuable. Our method achieves higher accuracy than the baselines across multiple classification tasks and a range of LLMs.
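The optimization step described above (descend on the loss, then project back toward the user-provided embeddings) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the constraint set is assumed to be a per-token L2 ball of radius `eps` around each original embedding (the abstract does not specify the norm or radius), embeddings are plain lists rather than model tensors, gradients come from a toy quadratic loss instead of an LLM, and the function names are hypothetical.

```python
import math

def project_l2_ball(e, e0, eps):
    """Project embedding e onto the L2 ball of radius eps centred at e0.
    The L2 ball is an assumed constraint set, used here for illustration."""
    delta = [a - b for a, b in zip(e, e0)]
    norm = math.sqrt(sum(d * d for d in delta))
    if norm <= eps:
        return list(e)
    scale = eps / norm
    return [b + d * scale for b, d in zip(e0, delta)]

def pgd_descent_step(embs, orig, grads, lr, eps, updatable):
    """One projected gradient *descent* step over a sequence of token embeddings.

    Updatable tokens move against the gradient (minimizing the loss, whereas an
    adversarial attack would ascend it) and are then projected back near their
    original embeddings. Frozen tokens, e.g. labels, are copied unchanged.
    """
    new_embs = []
    for e, e0, g, upd in zip(embs, orig, grads, updatable):
        if not upd:
            new_embs.append(list(e))
            continue
        stepped = [x - lr * gx for x, gx in zip(e, g)]
        new_embs.append(project_l2_ball(stepped, e0, eps))
    return new_embs

# Toy usage (hypothetical): loss = sum_t ||e_t - target_t||^2, so grad_t = 2 (e_t - target_t).
orig = [[0.0, 0.0], [1.0, 1.0]]
target = [[3.0, 4.0], [5.0, 5.0]]
updatable = [True, False]  # the second token plays the role of a fixed label
embs = [list(e) for e in orig]
for _ in range(50):
    grads = [[2 * (x - t) for x, t in zip(e, tg)] for e, tg in zip(embs, target)]
    embs = pgd_descent_step(embs, orig, grads, lr=0.1, eps=1.0, updatable=updatable)
```

The first token is pulled toward its target but ends up on the boundary of the radius-1 ball, at approximately (0.6, 0.8), while the frozen second token keeps its original value. The projection is what keeps the optimized context close to the examples the user supplied.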

Paper Structure

This paper contains 19 sections, 2 equations, 6 figures, 2 tables.

Figures (6)

  • Figure 1: Overfitting Comparison: CPT vs. Baselines. Visualization of the train-test loss gap across methods and training-set sizes, using the GPT-j model on the DBpedia dataset. Each method has two loss curves: train loss (dotted line) and test loss (solid line). CPT mitigates overfitting better than the optimization-based methods: despite a relatively higher training loss, it achieves the lowest test loss.
  • Figure 2: Comparison of Baseline Algorithms and Token Utilization. We highlight the key differences between CPT and the baselines ICL, PT, and IPT. For each method, we mark two types of tokens: those used for loss calculation (red line) and those updated during optimization (green line). CPT shows its "$\text{Learnable Sample}_i$" tokens in two colors, reflecting their progression from "Sample tokens" to "Learnable tokens" as they are optimized.
  • Figure 3: Few-Shot Methods Comparison. We compare CPT with the baselines in few-shot settings, using the GPT-j model on the DBpedia dataset, and show that CPT performs especially well when only a limited number of examples is available. Context-based methods also hit memory constraints (marked with a dot) once the number of training examples exceeds a certain level.
  • Figure 4: Overview of CPT Training Process. We begin by arranging the data: each training example is embedded into the input-output templates, giving $[X_{\text{Emb}_i}]_{i=1}^{N}$, and these are concatenated to form $X_{\text{Context}}$. We then append a randomly selected training example, in this case $X_{\text{Emb}_2}$, to form the complete training example $X_{\text{Train}_2}$. During training, the input is passed through the frozen LLM, and the loss is computed over all labels present in $X_{\text{Train}_2}$, covering both the context labels and the training label. The context is updated, but its labels remain unchanged.
  • Figure 5: Set Classification Dataset. Mean accuracy of our CPT method versus the baselines as the number of training examples increases, using the GPT-j model.
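The data arrangement in Figure 4, together with the two token roles from Figure 2 (tokens used for the loss and tokens updated during optimization), can be sketched as follows. This is a minimal illustration, not the authors' implementation: the `input:`/`output:` template, whitespace tokenization, and the helper names (`embed_example`, `build_train_sequence`) are hypothetical. Marking every non-label context token as updatable is also a simplifying assumption, since the paper only says that specific context tokens are modified.

```python
import random
from dataclasses import dataclass

# Hypothetical template markers; the paper's actual templates are not reproduced here.
INPUT_MARKER = "input:"
OUTPUT_MARKER = "output:"

@dataclass
class Token:
    text: str
    in_loss: bool    # label token: contributes to the loss
    updatable: bool  # context token whose embedding is optimized

def embed_example(x, y, in_context):
    """Wrap one (input, label) pair in the template and tag each token.
    Whitespace splitting stands in for a real tokenizer."""
    tokens = [Token(w, in_loss=False, updatable=in_context)
              for w in [INPUT_MARKER, *x.split(), OUTPUT_MARKER]]
    # Labels always feed the loss and are never modified, even inside the context.
    tokens += [Token(w, in_loss=True, updatable=False) for w in y.split()]
    return tokens

def build_train_sequence(examples, rng):
    """Concatenate all templated examples into X_Context, then append one
    randomly chosen example j to form X_Train_j. Returns (j, tokens)."""
    context = []
    for x, y in examples:
        context += embed_example(x, y, in_context=True)
    j = rng.randrange(len(examples))
    x_j, y_j = examples[j]
    # Assumed here: only the context is optimized, so the appended example's tokens are not updated.
    return j, context + embed_example(x_j, y_j, in_context=False)

examples = [("great movie", "positive"), ("dull plot", "negative")]
j, seq = build_train_sequence(examples, random.Random(0))
loss_positions = [i for i, t in enumerate(seq) if t.in_loss]
update_positions = [i for i, t in enumerate(seq) if t.updatable]
```

Here `loss_positions` is `[4, 9, 14]` (both context labels plus the appended label), and `update_positions` covers the non-label tokens of the two context examples, `[0, 1, 2, 3, 5, 6, 7, 8]`. Keeping the two flags separate makes it explicit that context labels contribute to the loss without ever being modified.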