Linear-Time Demonstration Selection for In-Context Learning via Gradient Estimation

Ziniu Zhang, Zhenshuo Zhang, Dongyue Li, Lu Wang, Jennifer Dy, Hongyang R. Zhang

TL;DR

This work addresses efficient demonstration selection for in-context learning (ICL) by introducing a gradient-based, first-order loss estimation approach in the input-embedding space. The method precomputes model outputs and gradients once around an anchor prompt $\phi(S_0,x)$. It then applies a first-order Taylor expansion to estimate losses for many candidate subsets, which enables a linear-time selection procedure. Empirically, the estimates approximate full inference with less than $1\%$ error across six datasets. Selection runs up to $37.7\times$ faster than with full inference on models with up to $34$B parameters. The selected demonstrations also outperform embedding-similarity-based selection by about $11\%$ on average, and the method maintains strong results with much shorter contexts in long-context settings. The approach scales demonstration selection to large prompt pools and offers practical benefits for prompt tuning and chain-of-thought reasoning, with an open-source implementation available.
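Written out in this page's notation, the first-order estimate can be sketched as follows. This is a hedged reconstruction rather than the paper's exact formula. The loss $\ell$, the query set $Q$, and the labels $y$ are assumptions introduced here. For vector-valued outputs such as logits, the gradient becomes a Jacobian.

```latex
\hat{f}_W\big(\phi(S,x)\big) = f_W\big(\phi(S_0,x)\big)
  + \nabla f_W\big(\phi(S_0,x)\big)^{\top}\big(\phi(S,x) - \phi(S_0,x)\big),
\qquad
\hat{h}(S) = \frac{1}{|Q|}\sum_{(x,y)\in Q} \ell\big(\hat{f}_W(\phi(S,x)),\, y\big).
```

Only $f_W(\phi(S_0,x))$ and its gradient require model calls. Each additional candidate subset $S$ then costs an embedding difference and an inner product, which is where the "precompute once" saving comes from.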

Abstract

This paper introduces an algorithm to select demonstration examples for in-context learning of a query set. Given a set of $n$ examples, how can we quickly select $k$ out of $n$ to best serve as the conditioning for downstream inference? This problem has broad applications in prompt tuning and chain-of-thought reasoning. Since model weights remain fixed during in-context learning, previous work has sought to design methods based on the similarity of token embeddings. This work proposes a new approach based on gradients of the output taken in the input embedding space. Our approach estimates model outputs through a first-order approximation using the gradients. Then, we apply this estimation to multiple randomly sampled subsets. Finally, we aggregate the sampled subset outcomes to form an influence score for each demonstration and select the $k$ most relevant examples. This procedure only requires pre-computing model outputs and gradients once, so its running time is linear in the model and training set sizes. Extensive experiments across various models and datasets validate the efficiency of our approach. We show that the gradient estimation procedure approximates full inference with less than $1\%$ error across six datasets. This allows us to scale up subset selection, running up to $37.7\times$ faster than selection with full inference on models with up to $34$ billion parameters. It also outperforms existing selection methods based on input embeddings by $11\%$ on average.
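The "pre-compute once, then estimate many subsets" step described in the abstract can be illustrated with a minimal toy sketch. The following are hypothetical stand-ins, not the paper's setup:

- `tanh` of a linear score plays the frozen model $f_W$.
- The prompt embedding $\phi(S,x)$ is modeled as the query embedding plus the mean embedding of the demonstrations in $S$.

In the actual method, both the embeddings and the gradient come from the language model itself. All function names (`phi`, `first_order_estimates`) are illustrative.

```python
import math
import random


# Hypothetical toy stand-in: demonstration i has a d-dimensional embedding
# demo_emb[i], and the prompt embedding phi(S, x) is modeled as the query
# embedding plus the mean embedding of the demonstrations in subset S.
def phi(subset, demo_emb, query_emb):
    d = len(query_emb)
    mu = [sum(demo_emb[i][j] for i in subset) / len(subset) for j in range(d)]
    return [q + c for q, c in zip(query_emb, mu)]


# Toy frozen "model" f_W(z) = tanh(w . z) and its exact gradient in z.
def f(z, w):
    return math.tanh(sum(wi * zi for wi, zi in zip(w, z)))


def grad_f(z, w):
    t = f(z, w)
    return [(1.0 - t * t) * wi for wi in w]


def first_order_estimates(subsets, anchor, demo_emb, query_emb, w):
    """Estimate f_W(phi(S, x)) for every subset S from one anchor evaluation."""
    z0 = phi(anchor, demo_emb, query_emb)
    f0, g0 = f(z0, w), grad_f(z0, w)  # Stage 1: one output and one gradient
    estimates = []
    for s in subsets:  # Stage 2: embedding difference + dot product, no model call
        dz = [a - b for a, b in zip(phi(s, demo_emb, query_emb), z0)]
        estimates.append(f0 + sum(gi * di for gi, di in zip(g0, dz)))
    return estimates


random.seed(0)
n, d, k, m = 20, 8, 4, 50  # pool size, embedding dim, subset size, #subsets
demo_emb = [[random.gauss(0.0, 0.1) for _ in range(d)] for _ in range(n)]
query_emb = [random.gauss(0.0, 1.0) for _ in range(d)]
w = [random.gauss(0.0, 0.3) for _ in range(d)]
anchor = list(range(k))  # S_0: an arbitrary anchor subset
subsets = [random.sample(range(n), k) for _ in range(m)]

est = first_order_estimates(subsets, anchor, demo_emb, query_emb, w)
exact = [f(phi(s, demo_emb, query_emb), w) for s in subsets]  # "full inference"
max_err = max(abs(a - b) for a, b in zip(est, exact))
print(f"max |estimate - exact| over {m} subsets: {max_err:.2e}")
```

In this toy, the comparison against `exact` stands in for running full inference on every subset. It is the cost that the first-order estimate avoids.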

Paper Structure

This paper contains 21 sections, 1 theorem, 6 equations, 4 figures, 12 tables, 4 algorithms.

Key Result

lemma 1

Let $0 < \varepsilon < 1$, let $X$ be a set of $N$ points in $\mathbb{R}^n$, and let $k \geq C\varepsilon^{-2}\log N$ for some universal constant $C > 0$. Then there exists a linear map $f : \mathbb{R}^n \to \mathbb{R}^k$ such that for all $u, v \in X$,
$$(1-\varepsilon)\,\|u - v\|^2 \;\leq\; \|f(u) - f(v)\|^2 \;\leq\; (1+\varepsilon)\,\|u - v\|^2.$$
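The lemma is existential. A standard way to realize such an $f$, not taken from this paper, is a random matrix with i.i.d. Gaussian entries of variance $1/k$, which satisfies the bound with high probability. The pure-Python sketch below checks this empirically. The dimensions and seed are arbitrary choices. This section does not say how the paper applies the lemma, so the sketch only illustrates the lemma itself.

```python
import math
import random


def gaussian_projection(n_dim, k, rng):
    """k x n_dim matrix with i.i.d. N(0, 1/k) entries, so E||A u||^2 = ||u||^2."""
    scale = 1.0 / math.sqrt(k)
    return [[rng.gauss(0.0, scale) for _ in range(n_dim)] for _ in range(k)]


def project(A, u):
    # Linear map f(u) = A u.
    return [sum(a * x for a, x in zip(row, u)) for row in A]


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


rng = random.Random(0)
N, n_dim, k = 20, 500, 300  # N points in R^n_dim mapped to R^k
X = [[rng.gauss(0.0, 1.0) for _ in range(n_dim)] for _ in range(N)]
A = gaussian_projection(n_dim, k, rng)
FX = [project(A, x) for x in X]

# Ratio ||f(u) - f(v)||^2 / ||u - v||^2 for every pair; the lemma asks that
# all ratios lie in [1 - eps, 1 + eps].
ratios = [sq_dist(FX[i], FX[j]) / sq_dist(X[i], X[j])
          for i in range(N) for j in range(i + 1, N)]
print(f"min ratio {min(ratios):.3f}, max ratio {max(ratios):.3f}")
```

For a Gaussian matrix, each ratio is distributed as $\chi^2_k / k$, so its spread shrinks roughly like $\sqrt{2/k}$ as the target dimension $k$ grows.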

Figures (4)

  • Figure 1: Given a set of demonstrations, we design a linear-time demonstration selection algorithm to construct prompts for in-context learning of a query set. Stage 1: Pre-compute functional outputs and gradients (with respect to the embedding vector) on the entire training set. Stage 2: Apply a first-order approximation to estimate the model outputs on $m$ random subsets $S_1, S_2, \dots, S_m$, using the outputs and gradients computed in Stage 1. Let $\hat{h}(S_1), \hat{h}(S_2), \ldots, \hat{h}(S_m)$ denote the estimated results, i.e., the estimated loss values of evaluating $f_W$ with each subset as the prompt conditioning. Stage 3: Compute an influence score $s_i$ for each demonstration example from the estimates $\hat{h}$, for $i = 1, 2, \dots, n^{\text{Demo}}$; $s_i$ can be read as the importance of the $i$-th demonstration for the query set. Finally, select $k$ out of the $n$ demonstrations by thresholding the scores at $\lambda$.
  • Figure 2: An illustration of our approach for in-context learning of linear functions, as we vary the number of in-context examples $k$. First panel: our approach incurs low approximation error relative to full inference, for both linear and nonlinear regression. Second panel: with our approach, the selected demonstrations follow the same linear function $\beta$, achieving lower error than top-$k$ and random-$k$ selection.
  • Figure 3: Trade-off between the number of FLOPs and test error rates, measured on two datasets with DeepSeek-7B models.
  • Figure 4: Comparing our approach with top-$k$ and random-$k$ by varying $k$ on two datasets using DeepSeek-7B models. Here, $k$ varies from $0$ to $150$.
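Stage 3 in the Figure 1 caption (turning subset-level estimates $\hat{h}(S_j)$ into per-example scores $s_i$ and thresholding at $\lambda$) can be sketched as below. This section does not define how the scores are aggregated. The difference of means used here is an assumed, hypothetical choice: the average estimated loss of subsets without example $i$ minus the average over subsets with $i$, so higher means more helpful. The synthetic losses are fabricated for illustration only, and the function names are illustrative.

```python
import random
from statistics import mean


def influence_scores(n, subsets, losses):
    """Hypothetical score: mean estimated loss of sampled subsets WITHOUT
    example i minus the mean over subsets WITH i (higher = more helpful)."""
    scores = []
    for i in range(n):
        with_i = [h for s, h in zip(subsets, losses) if i in s]
        without_i = [h for s, h in zip(subsets, losses) if i not in s]
        if with_i and without_i:
            scores.append(mean(without_i) - mean(with_i))
        else:
            scores.append(0.0)  # i never (or always) sampled: no signal
    return scores


def select_demonstrations(scores, k):
    """Keep the k highest-scoring examples; lambda is the k-th largest score."""
    ranked = sorted(range(len(scores)), key=lambda i: -scores[i])  # stable ties
    chosen = ranked[:k]
    lam = scores[chosen[-1]]
    return chosen, lam


# Synthetic estimated losses h_hat(S_j): by construction, examples 0-2 are the
# only helpful ones, each lowering the loss by 0.2, plus small Gaussian noise.
random.seed(1)
n, k, m = 12, 3, 400
helpful = {0, 1, 2}
subsets = [set(random.sample(range(n), 4)) for _ in range(m)]
losses = [1.0 - 0.2 * len(helpful & s) + random.gauss(0.0, 0.02) for s in subsets]

scores = influence_scores(n, subsets, losses)
chosen, lam = select_demonstrations(scores, k)
print("selected:", sorted(chosen), "lambda =", round(lam, 3))
```

Setting $\lambda$ implicitly to the $k$-th largest score keeps exactly $k$ examples. A fixed $\lambda$ would instead let the number of selected demonstrations vary with the score distribution.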

Theorems & Definitions (1)

  • lemma 1: The Johnson-Lindenstrauss Lemma