On Understanding Attention-Based In-Context Learning for Categorical Data
Aaron T. Wang, William Convertino, Xiang Cheng, Ricardo Henao, Lawrence Carin
TL;DR
The paper develops an attention-based framework that exactly implements multi-step functional gradient descent (GD) for in-context learning with categorical outputs. By interleaving self-attention and cross-attention layers, and by using embeddings of the categorical tokens, the model emulates GD updates on a latent function that feeds a softmax predictor. The authors show theoretically that the parameter setting that implements GD is a stationary point of the attention-based objective, and they validate the approach on synthetic data, in-context image classification, and language modeling, reporting competitive performance with far fewer parameters than standard Transformers. The work offers a principled GD perspective on in-context learning for discrete data and indicates practical benefits for data-efficient, context-driven inference. It also suggests that adding a feedforward component could close the remaining gap to full Transformer performance on language tasks.
Abstract
In-context learning based on attention models is examined for data with categorical outcomes, with inference in such models viewed from the perspective of functional gradient descent (GD). We develop a network composed of attention blocks, each of which applies a self-attention layer followed by a cross-attention layer, with associated skip connections. This model can exactly perform multi-step functional GD for in-context inference with categorical observations. We perform a theoretical analysis of this setup, generalizing many prior assumptions in this line of work, including the class of attention mechanisms for which it is appropriate. We demonstrate the framework empirically on synthetic data, image classification, and language generation.
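
To make the functional GD view concrete, the following is a minimal NumPy sketch of multi-step functional GD on a latent function f whose output feeds a softmax over category embeddings. It is an illustration under stated assumptions, not the authors' implementation: the RBF kernel, the zero initialization f_0 = 0, the one-hot category embeddings used in the usage lines, and all function and parameter names (`functional_gd_predict`, `rbf_kernel`, `bandwidth`, `lr`, `steps`) are hypothetical choices made for this example. Each step moves f(x) along the kernel-weighted sum of residuals w_{y_i} - E[w | f(x_i)] over the context points. One possible reading of the attention blocks, also an assumption here, is that computing E[w | f(x_i)] resembles a softmax cross-attention over the category embeddings, while the kernel-weighted sum resembles a self-attention over the context.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def rbf_kernel(a, b, bandwidth=1.0):
    # Pairwise RBF kernel between rows of a (n, p) and rows of b (m, p).
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dist / (2.0 * bandwidth ** 2))

def functional_gd_predict(x_ctx, y_ctx, x_query, w, steps=50, lr=1.0, bandwidth=1.0):
    """Run `steps` functional GD updates and return query class probabilities.

    x_ctx: (n, p) context inputs; y_ctx: (n,) integer labels in [0, C).
    x_query: (m, p) query inputs; w: (C, d) category embeddings (assumed given).
    """
    n = x_ctx.shape[0]
    x_all = np.vstack([x_ctx, x_query])
    k = rbf_kernel(x_ctx, x_all, bandwidth)        # (n, n + m)
    f = np.zeros((x_all.shape[0], w.shape[1]))     # f_0 = 0 at every point
    for _ in range(steps):
        probs = softmax(f[:n] @ w.T)               # p(c | f(x_i)), shape (n, C)
        expected_w = probs @ w                     # E[w | f(x_i)], shape (n, d)
        residual = w[y_ctx] - expected_w           # negative gradient of the NLL
        f = f + (lr / n) * (k.T @ residual)        # kernel-weighted update
    return softmax(f[n:] @ w.T)                    # (m, C)

# Usage on a toy problem: three well-separated clusters, one per category.
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
y_ctx = np.repeat(np.arange(3), 10)
x_ctx = centers[y_ctx] + 0.3 * rng.standard_normal((30, 2))
w = np.eye(3)  # one-hot category embeddings, an assumption for this example
query_probs = functional_gd_predict(x_ctx, y_ctx, centers, w)
print(query_probs.argmax(axis=1))
```

In this sketch only the context points contribute residuals, while the update is evaluated at both context and query points, so the query prediction is driven entirely by the in-context examples and no parameters are trained.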
