On Understanding Attention-Based In-Context Learning for Categorical Data

Aaron T. Wang, William Convertino, Xiang Cheng, Ricardo Henao, Lawrence Carin

TL;DR

The paper develops an attention-based framework that implements exact multi-step functional gradient descent (GD) for in-context learning with categorical outputs. By interleaving self-attention and cross-attention, and by embedding token information, the model emulates GD updates on a latent function whose output feeds a softmax predictor. The authors prove that GD is a stationary point of the attention-based training objective, and they validate the approach on synthetic data, in-context image classification, and language modeling, reporting competitive performance with far fewer attention parameters than standard Transformers (e.g., 8K versus 6M attention-weight parameters in the story-generation comparison). The work offers a principled GD perspective on in-context learning for discrete data and shows practical benefits for data-efficient, context-driven inference. It also suggests that adding a feedforward component can narrow the gap to full Transformer performance on language tasks.
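As a hedged illustration of the GD view summarized above (our notation and assumptions, not necessarily the paper's exact formulation): take context pairs (x_i, y_i), i = 1, ..., N, with labels y_i in {1, ..., C} and one-hot vectors e_{y_i}; let the latent function f output C logits directly, with predictor p(y = c | x) = softmax(f(x))_c; and take the functional gradient in the RKHS of a kernel k with step size α. Each GD step is then a kernel-weighted sum of residuals, which is the form an attention layer computes when its similarity function plays the role of k.

```latex
\begin{aligned}
\mathcal{L}(f) &= -\frac{1}{N}\sum_{i=1}^{N}\log\operatorname{softmax}\big(f(x_i)\big)_{y_i},\\
\nabla_f\mathcal{L}(f)(\cdot) &= -\frac{1}{N}\sum_{i=1}^{N} k(x_i,\cdot)\,\big(e_{y_i}-\operatorname{softmax}(f(x_i))\big),\\
f_{t+1}(x) &= f_t(x) - \alpha\,\nabla_f\mathcal{L}(f_t)(x)
= f_t(x) + \frac{\alpha}{N}\sum_{i=1}^{N} k(x_i,x)\,\big(e_{y_i}-\operatorname{softmax}(f_t(x_i))\big).
\end{aligned}
```

In one natural reading, evaluating this update at the context inputs (x = x_j) and at the query input corresponds to the work of the self-attention and cross-attention layers, respectively.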

Abstract

In-context learning based on attention models is examined for data with categorical outcomes, with inference in such models viewed from the perspective of functional gradient descent (GD). We develop a network composed of attention blocks, each consisting of a self-attention layer followed by a cross-attention layer, with associated skip connections. This model can perform exact multi-step functional GD for in-context inference with categorical observations. We analyze this setup theoretically, generalizing many assumptions made in prior work in this line, including the class of attention mechanisms for which the construction is appropriate. We demonstrate the framework empirically on synthetic data, image classification, and language generation.
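The block structure described in the abstract can be sketched in NumPy under explicit assumptions of our own: a context of N labeled pairs (x_i, y_i) with one-hot targets e_{y_i}, a latent function f that outputs C logits directly, keys and queries equal to the raw inputs x, and values equal to the residuals e_{y_i} - softmax(f(x_i)). Self-attention updates f at the context tokens, cross-attention updates f at the query tokens, and skip connections add each update to the current f. Values are computed once per block so both layers apply the same GD step. With an unnormalized RBF similarity, a block reproduces one kernel GD step exactly; softmax attention instead swaps in a row-normalized dot-product kernel. All names (`gd_attention_block`, `stacked_blocks_predict`, `kind`, `gamma`, `alpha`) are hypothetical, and the paper's actual parameterization (e.g., learned projections or category embeddings) may differ.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention_weights(queries, keys, kind="rbf", gamma=1.0):
    # "rbf": unnormalized exp(-gamma * ||q - k||^2), i.e. a kernel k(x_i, x).
    # "softmax": dot-product scores normalized over the keys (rows sum to 1).
    if kind == "rbf":
        sq = ((queries[:, None, :] - keys[None, :, :]) ** 2).sum(-1)
        return np.exp(-gamma * sq)
    if kind == "softmax":
        scores = queries @ keys.T
        w = np.exp(scores - scores.max(axis=1, keepdims=True))
        return w / w.sum(axis=1, keepdims=True)
    raise ValueError(f"unknown attention kind: {kind}")


def gd_attention_block(x_ctx, f_ctx, onehot, x_q, f_q, alpha=1.0, kind="rbf", gamma=1.0):
    # Hypothetical block: one functional GD step on the context cross-entropy.
    n = x_ctx.shape[0]
    # Values shared by both layers: e_{y_i} - softmax(f(x_i)), the negative
    # gradient of the cross-entropy with respect to the logits at x_i.
    values = onehot - softmax(f_ctx)
    # Self-attention among context tokens, added via a skip connection.
    f_ctx_next = f_ctx + (alpha / n) * attention_weights(x_ctx, x_ctx, kind, gamma) @ values
    # Cross-attention from query tokens to context tokens, plus skip connection.
    f_q_next = f_q + (alpha / n) * attention_weights(x_q, x_ctx, kind, gamma) @ values
    return f_ctx_next, f_q_next


def stacked_blocks_predict(x_ctx, y_ctx, x_q, num_classes, num_blocks=3, **kw):
    # Stacking num_blocks blocks performs num_blocks GD steps starting from f = 0.
    onehot = np.eye(num_classes)[y_ctx]
    f_ctx = np.zeros((x_ctx.shape[0], num_classes))
    f_q = np.zeros((x_q.shape[0], num_classes))
    for _ in range(num_blocks):
        f_ctx, f_q = gd_attention_block(x_ctx, f_ctx, onehot, x_q, f_q, **kw)
    return softmax(f_q)  # class probabilities at the query inputs
```

For example, `stacked_blocks_predict(x_ctx, y_ctx, x_q, num_classes=2, kind="softmax")` returns an array of shape (len(x_q), 2) whose rows sum to 1; increasing `num_blocks` corresponds to taking more GD steps.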

Paper Structure

This paper contains 44 sections, 77 equations, 12 figures, and 6 tables.

Figures (12)

  • Figure 1: Summary of interleaved attention for multi-step GD with categorical observations.
  • Figure 2: For one self-attention layer, comparison of accuracy and loss of GD and Trained TF with RBF (left) and softmax (right) attention, as a function of contextual training sets. Error bars are computed from five different random initializations.
  • Figure 3: For the multi-layer model in the multi-step GD section, training curves of accuracy (left) and negative log-likelihood (right) for the GD transformer with softmax attention. Error bars are computed from five different random initializations.
  • Figure 4: (Left) For the model in the multi-step GD section with two attention blocks, comparison of accuracy and loss of GD and Trained TF with softmax attention, as a function of contextual training sets. (Right) Model accuracy on the ImageNet dataset; the attention-based model was trained once for 500 epochs, and the linear probe was trained for 500 epochs on each test context.
  • Figure 5: GPT-4o scoring of the generated story endings. Each item is graded out of a maximum score of 10. Softmax GD and Softmax GD with FF have 8K attention-weight parameters, and the Transformer has 6M attention-weight parameters.
  • ...and 7 more figures