Learning Task Representations from In-Context Learning

Baturay Saglam, Xinyang Hu, Zhuoran Yang, Dionysis Kalogerias, Amin Karbasi

TL;DR

This work examines how in-context learning encodes tasks within transformers and proposes the Learnable Task Vector (LTV), a per-layer, head-weighted task representation optimized via gradient descent. By treating a task as a weighted sum of attention-head activations, LTV generalizes across text and functional regression, and it is validated on a new regression benchmark designed to probe cross-modality task fidelity. Empirical results show that LTV preserves task information in out-of-distribution prompts and often outperforms strong baselines; the learned weights reveal selective head involvement, and the last hidden-state distributions align with those of an optimal ICL model. The findings advance the understanding of task steering in LLMs and offer a practical, modality-agnostic way to encode tasks without fine-tuning. The work also highlights the role of attention-head contributions and per-layer task vectors in capturing robust, cross-domain task representations, which may inform future PEFT and interpretability approaches.

Abstract

Large language models (LLMs) have demonstrated remarkable proficiency in in-context learning (ICL), where models adapt to new tasks through example-based prompts without requiring parameter updates. However, understanding how tasks are internally encoded and generalized remains a challenge. To address some of the empirical and technical gaps in the literature, we introduce an automated formulation for encoding task information in ICL prompts as a function of attention heads within the transformer architecture. This approach computes a single task vector as a weighted sum of attention heads, with the weights optimized causally via gradient descent. Our findings show that existing methods fail to generalize effectively to modalities beyond text. In response, we also design a benchmark to evaluate whether a task vector can preserve task fidelity in functional regression tasks. The proposed method successfully extracts task-specific information from in-context demonstrations and excels in both text and regression tasks, demonstrating its generalizability across modalities.
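
The core idea in the abstract, a task vector formed as a weighted sum of attention-head activations with weights fitted by gradient descent, can be illustrated with a toy sketch. Everything below is an assumption made for illustration: the names `task_vector` and `fit_head_weights`, the random "head activations", and especially the stand-in objective (squared distance to a fixed target vector). In the paper the weights are optimized through the transformer's own loss, which this sketch does not reproduce.

```python
import numpy as np

def task_vector(head_acts, weights):
    """Weighted sum of per-head activations: v = sum_h w_h * a_h."""
    # head_acts: (n_heads, d_model); weights: (n_heads,)
    return weights @ head_acts

def fit_head_weights(head_acts, target, steps=2000):
    """Plain gradient descent on the stand-in loss 0.5 * ||v(w) - target||^2."""
    # Step size 1/L, where L is the largest eigenvalue of head_acts @ head_acts.T
    # (the squared spectral norm), keeps the iteration stable.
    lr = 1.0 / np.linalg.norm(head_acts, 2) ** 2
    w = np.zeros(head_acts.shape[0])
    losses = []
    for _ in range(steps):
        residual = task_vector(head_acts, w) - target  # shape (d_model,)
        losses.append(0.5 * float(residual @ residual))
        grad = head_acts @ residual                    # dLoss/dw, shape (n_heads,)
        w = w - lr * grad
    return w, losses

# Hypothetical toy data: 8 heads, model width 16.
rng = np.random.default_rng(0)
head_acts = rng.normal(size=(8, 16))
true_w = np.array([1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, -0.3])
target = true_w @ head_acts  # stand-in for the "ideal" task vector

w, losses = fit_head_weights(head_acts, target)
print("first loss:", losses[0], "last loss:", losses[-1])
```

In this toy setting only three heads carry nonzero weight in `true_w`, which loosely mirrors the selective head involvement mentioned in the TL;DR; the fitted `w` can be inspected to see which heads the optimization relies on.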

Paper Structure

This paper contains 87 sections, 38 equations, 15 figures, 3 tables.

Figures (15)

  • Figure 1: Squared error on the query input as a function of the number of demonstrations in the prompts, where lower error $(f(x_{\text{query}}) - M_{\theta}(x_{\text{query}} \mid p^f))^2$ indicates better ICL performance. We evaluate three different function classes: (a) linear functions $f(x_\text{query}) = w^\top x_\text{query}$, (b) sparse linear functions $f(x_\text{query}) = w_s^\top x_\text{query}$, and (c) 2-layer ReLU neural networks (NNs) $f(x_\text{query}) = W_2 \operatorname{ReLU}(W_1 x_\text{query})$. Results are averaged over a batch of 256 tasks randomly selected from the same function class. The shaded area represents the 95% confidence interval over the sampled prompts. The dashed line indicates the maximum sequence length used during the positional encoding training, $T_\text{train}$.
  • Figure 2: Illustration of the operation. Additional and output operations may include residual connections, normalization, feedforward, or prediction layers, depending on the architecture. LTV is added sequentially to each layer, allowing the effects of the integrated LTV to be progressively observed across subsequent layers.
  • Figure 3: Squared error on the query input, averaged over a batch of 256 tasks randomly selected from the same function class. The shaded area represents the 95% confidence interval over the sampled prompts. The dashed line indicates the number of examples the transformer was trained with, and $T_\mathbf{v}$ denotes the prompt length used in LTV training. Complete results for different $T_\mathbf{v}$ values are provided in Figures \ref{fig:complete_results_lin}, \ref{fig:complete_results_sparse_lin}, and \ref{fig:complete_results_relu} in Appendix \ref{app:complete_regression_evals}.
  • Figure 4: The diagram illustrates our pipeline for ablation studies. We start by collecting $M = 25{,}600$ prompts corresponding to a selected task $f$. Subsequently, the first principal component of the column space of these datasets is extracted using SVD. Finally, we report the KL divergence between the KDE-estimated distributions of these components.
  • Figure 5: Evaluation on the class of linear functions, with the transformer pre-trained on up to $T_\text{train} = 41$ examples per prompt. Results are averaged over a batch of 256 randomly selected tasks. The shaded area represents the 95% confidence interval over the sampled prompts. $T_\mathbf{v}$ denotes the prompt length used during LTV training.
  • ...and 10 more figures
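
The comparison pipeline described in the Figure 4 caption (SVD to obtain a first principal component, kernel density estimation, then KL divergence) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, the data are synthetic Gaussian samples rather than model hidden states, both sets are projected onto the principal direction of the reference set, the KDE uses a Gaussian kernel with a Silverman-style rule-of-thumb bandwidth, and the KL divergence is approximated on a shared grid.

```python
import numpy as np

def gaussian_kde(samples, grid):
    """Gaussian KDE on `grid` with a rule-of-thumb bandwidth (assumption)."""
    n = samples.size
    bandwidth = 1.06 * samples.std() * n ** (-0.2)
    z = (grid[:, None] - samples[None, :]) / bandwidth
    return np.exp(-0.5 * z ** 2).sum(axis=1) / (n * bandwidth * np.sqrt(2 * np.pi))

def kl_divergence(p, q, dx, eps=1e-12):
    """Grid approximation of KL(p || q) after renormalizing both densities."""
    p = p / (p.sum() * dx)
    q = q / (q.sum() * dx)
    return float(np.sum(p * np.log((p + eps) / (q + eps))) * dx)

def first_pc_kl(reference, other, n_grid=512):
    """KL between KDEs of both sets projected on the reference's first PC."""
    mean = reference.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions of the centered reference set.
    _, _, vt = np.linalg.svd(reference - mean, full_matrices=False)
    direction = vt[0]
    a = (reference - mean) @ direction
    b = (other - mean) @ direction
    lo, hi = min(a.min(), b.min()), max(a.max(), b.max())
    pad = 0.1 * (hi - lo)
    grid = np.linspace(lo - pad, hi + pad, n_grid)
    dx = grid[1] - grid[0]
    return kl_divergence(gaussian_kde(a, grid), gaussian_kde(b, grid), dx)

# Synthetic stand-ins for two collections of vectors (hypothetical data).
rng = np.random.default_rng(0)
scale = np.array([3.0, 1.0, 1.0, 1.0])   # first coordinate dominates the variance
reference = rng.normal(size=(500, 4)) * scale
similar = rng.normal(size=(500, 4)) * scale
shifted = similar + np.array([4.0, 0.0, 0.0, 0.0])

kl_same = first_pc_kl(reference, reference)
kl_near = first_pc_kl(reference, similar)
kl_far = first_pc_kl(reference, shifted)
print(kl_same, kl_near, kl_far)
```

Projecting both sets onto a single shared direction sidesteps the sign ambiguity of singular vectors; whether the paper projects onto a shared or a per-dataset component is not stated in the caption, so that choice is an assumption of this sketch.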