
Softmax $\geq$ Linear: Transformers may learn to classify in-context by kernel gradient descent

Sara Dragutinović, Andrew M. Saxe, Aaditya K. Singh

TL;DR

This work investigates in-context learning (ICL) in transformers, focusing on the learning algorithms they execute when context is provided. It extends prior theory from linear self-attention (SA) to non-linear softmax attention in a discrete classification setting with cross-entropy loss. It shows that linear SA implements one step of gradient descent, while softmax SA performs a context-adaptive step of kernel gradient descent in a reproducing kernel Hilbert space (RKHS), governed by learned kernel-width and learning-rate parameters. Through controlled experiments with a one-layer, single-head transformer on a synthetic spherical-class task, the authors demonstrate that trained models approximate these GD-like updates, with softmax SA leveraging a context-adaptive learning rate and a meta-learned kernel shape to outperform linear SA. The findings offer theoretical grounding for ICL in practical settings and highlight the kernel-shape adaptation ability of softmax attention as a key factor in context sensitivity and performance.
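To make the "softmax SA as context-adaptive kernel GD" reading concrete, here is a minimal NumPy sketch under our own assumptions: an exponential dot-product kernel $k(\mathbf{x},\mathbf{x}')=\exp(\mathbf{x}\cdot\mathbf{x}'/\tau)$ with width $\tau$, one functional GD step on the CE loss starting from $f_0=0$, and a context-adaptive learning rate $\eta(\mathbf{x}_\text{query})=\eta_0 n/\sum_j k(\mathbf{x}_j,\mathbf{x}_\text{query})$. The function names and the parameters `tau` and `eta0` are hypothetical, and the paper's exact parameterization may differ. The sketch checks that this kernel GD step and a single-query softmax attention readout (keys $\mathbf{x}_i$, values $\mathbf{y}_i$) yield the same predicted class probabilities.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def kernel(X, xq, tau):
    """Assumed exponential dot-product kernel k(x_i, x_q) = exp(x_i . x_q / tau)."""
    return np.exp(X @ xq / tau)


def kernel_gd_logits(X, Y, xq, tau, eta0):
    """Hypothetical: one functional GD step on CE loss from f_0 = 0, at the query.

    At f_0 = 0 every prediction is uniform (1/C), so the functional gradient is
    (1/n) * sum_i (1/C - y_i) k(x_i, .). The step size is the context-adaptive
    learning rate eta(x_q) = eta0 * n / sum_j k(x_j, x_q).
    """
    n, C = Y.shape
    k = kernel(X, xq, tau)                      # kernel values, shape (n,)
    eta = eta0 * n / k.sum()                    # context-adaptive learning rate
    grad_at_query = ((1.0 / C) - Y).T @ k / n   # functional gradient at x_q, shape (C,)
    return -eta * grad_at_query


def softmax_attention_logits(X, Y, xq, tau, eta0):
    """Single-query softmax attention: keys x_i, values y_i, scores x_i . x_q / tau."""
    weights = softmax(X @ xq / tau)             # attention weights over the context
    return eta0 * (weights @ Y)                 # shape (C,)


# Random context on the unit sphere with one-hot labels.
rng = np.random.default_rng(0)
n, d, C, tau, eta0 = 100, 5, 5, 0.5, 2.0
X = rng.normal(size=(n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)
Y = np.eye(C)[rng.integers(C, size=n)]
xq = rng.normal(size=d)
xq /= np.linalg.norm(xq)

f_gd = kernel_gd_logits(X, Y, xq, tau, eta0)
f_attn = softmax_attention_logits(X, Y, xq, tau, eta0)
# The logit vectors differ only by the class-independent shift eta0 / C,
# so softmax(f_gd) and softmax(f_attn) coincide.
```

Dividing by $\sum_j k(\mathbf{x}_j,\mathbf{x}_\text{query})$ is what turns the unnormalized kernel sum into softmax weights, which is one way to see why the effective learning rate depends on the context.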

Abstract

The remarkable ability of transformers to learn new concepts solely by reading examples within the input prompt, termed in-context learning (ICL), is a crucial aspect of intelligent behavior. Here, we focus on understanding the learning algorithm transformers use to learn from context. Existing theoretical work, often based on simplifying assumptions, has primarily focused on linear self-attention and continuous regression tasks, finding that transformers can learn in-context by gradient descent. Given that transformers are typically trained on discrete and complex tasks, we bridge the gap from this existing work to the setting of classification, with non-linear (importantly, softmax) activation. We find that transformers still learn to do gradient descent in-context, though on functionals in the kernel feature space and, in the case of the softmax transformer, with a context-adaptive learning rate. These theoretical findings suggest a greater adaptability to context for softmax attention, which we empirically verify and study through ablations. Overall, we hope this enhances theoretical understanding of in-context learning algorithms in more realistic settings, advances our intuitions, and enables further theory bridging to larger models.

Paper Structure

This paper contains 34 sections, 2 theorems, 18 equations, 18 figures, 1 table.

Key Result

Proposition 3.1

Linear self-attention is expressive enough to implement one step of gradient descent on the cross-entropy (CE) loss in the linear classification setup, assuming we start from $W_0=\mathbf{0}$. Note that the assumption of starting from $W_0=\mathbf{0}$ corresponds to no prior knowledge on the classes, a realist…
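A minimal NumPy check of the algebra behind Proposition 3.1, assuming a linear classifier with logits $W\mathbf{x}$, the mean CE loss over the $n$ context examples, and one-hot labels $\mathbf{y}_i$. From $W_0=\mathbf{0}$ every prediction is uniform, so one step of size $\eta$ gives $W_1=\frac{\eta}{n}\sum_i(\mathbf{y}_i-\tfrac{1}{C}\mathbf{1})\mathbf{x}_i^\top$. The query logits $W_1\mathbf{x}_\text{query}$ then match a linear-attention readout $\frac{\eta}{n}\sum_i \mathbf{y}_i(\mathbf{x}_i\cdot\mathbf{x}_\text{query})$ up to a shift shared by all classes, which softmax ignores. The function names are hypothetical; this illustrates the update itself, not the paper's construction of the attention weight matrices.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def ce_loss(W, X, Y):
    """Mean cross-entropy of the linear classifier softmax(W x) on the context."""
    P = softmax(X @ W.T)                      # (n, C) predicted probabilities
    return -np.mean(np.sum(Y * np.log(P), axis=1))


def ce_grad(W, X, Y):
    """Analytic gradient dL/dW = (1/n) * sum_i (softmax(W x_i) - y_i) x_i^T."""
    P = softmax(X @ W.T)
    return (P - Y).T @ X / X.shape[0]         # (C, d)


def one_gd_step_logits(X, Y, xq, eta):
    """Query logits W_1 x_q after one GD step of size eta from W_0 = 0."""
    C, d = Y.shape[1], X.shape[1]
    W1 = -eta * ce_grad(np.zeros((C, d)), X, Y)
    return W1 @ xq


def linear_attention_logits(X, Y, xq, eta):
    """Linear-attention readout (eta / n) * sum_i y_i (x_i . x_q)."""
    return eta * Y.T @ (X @ xq) / X.shape[0]


rng = np.random.default_rng(0)
n, d, C, eta = 100, 5, 5, 1.0
X = rng.normal(size=(n, d))
Y = np.eye(C)[rng.integers(C, size=n)]
xq = rng.normal(size=d)

gd = one_gd_step_logits(X, Y, xq, eta)
lin = linear_attention_logits(X, Y, xq, eta)
shift = gd - lin   # every entry equals -(eta / (n * C)) * sum_i x_i . x_q
```

Because the shift is identical across classes, both readouts produce the same $\hat{\mathbf{y}}_\text{query}$ after softmax.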

Figures (18)

  • Figure 1: Two example contexts from our classification task (see the task setup section). For Context 1, the correct class for the query is Class 1, as $\mathbf{x}_\text{query}$ lies between the first and the third context vector. With Context 2, we emphasize how: 1) context vectors $\mathbf{x}_i$ and query input $\mathbf{x}_\text{query}$ differ between contexts; 2) class assignment differs between contexts: if the first context vector were in Context 1, its label would be Class 1, not Class 2.
  • Figure 2: Two different contexts of our synthetic classification task, with $C=5, n=100$ and a) $d=2$, b) $d=3$. Arrows represent class vectors and points represent context vectors; different colors correspond to different class labels.
  • Figure 3: Illustration of a transformer forward pass. In our setting, we only make use of the update of the last token: we extract its $y$-entry and apply softmax to get the final prediction $\hat{\mathbf{y}}_\text{query}$.
  • Figure 4: Similarity between the two algorithms (trained linear SA and a GD step) in the setup with $C=5, n=100, d=5$. a) Alignment metrics over the course of transformer training. Right: per-context similarity of b) loss, c) probability of the correct class and d) entropy of $\hat{\mathbf{y}}_\text{query}$; each point represents the value on one context. The dotted line corresponds to the mean value of a metric.
  • Figure 5: Similarity between a trained softmax SA and a context-adaptive step of kernel GD, in the setup $C=5, n=100, d=5$. a) Model alignment metrics over the course of transformer training. Right: per-context similarity of b) loss, c) probability of the correct class and d) entropy achieved by both algorithms.
  • ...and 13 more figures
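The pipeline described in Figures 1 to 3 can be sketched as follows. This is a hypothetical reconstruction, not the authors' code: the generator assumes unit class vectors and unit inputs on the sphere, with each input labelled by its most aligned class vector (the paper's sampling may differ); the query token is assumed to attend only to the context tokens; and the weights are hand-set rather than trained, so that scores are $\mathbf{x}_i\cdot\mathbf{x}_\text{query}/\tau$ and values copy $\eta_0\mathbf{y}_i$ into the label slot. `metrics` computes the three per-context quantities compared in Figures 4 and 5 (loss, probability of the correct class, entropy).

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def unit_rows(A):
    """Normalize each row of A to unit length (points on the sphere)."""
    return A / np.linalg.norm(A, axis=1, keepdims=True)


def sample_context(rng, C=5, n=100, d=5):
    """Hypothetical spherical-class task: random unit class vectors and unit
    inputs, each input labelled by the class vector it is most aligned with."""
    classes = unit_rows(rng.normal(size=(C, d)))
    X = unit_rows(rng.normal(size=(n + 1, d)))        # last row is the query
    labels = np.argmax(X @ classes.T, axis=1)
    Y = np.eye(C)[labels]
    return X[:-1], Y[:-1], X[-1], labels[-1]


def build_tokens(X, Y, xq):
    """Tokens z_i = [x_i; y_i]; the query token carries a zero label slot."""
    C = Y.shape[1]
    context = np.hstack([X, Y])
    query = np.concatenate([xq, np.zeros(C)])
    return np.vstack([context, query])                 # (n + 1, d + C)


def softmax_sa_predict(Z, W_Q, W_K, W_V, C):
    """One-layer, single-head softmax SA; only the last token's update is used.
    Assumption: the query token attends to the n context tokens only."""
    q = Z[-1] @ W_Q                                    # (d_k,)
    K = Z[:-1] @ W_K                                   # (n, d_k)
    V = Z[:-1] @ W_V                                   # (n, d + C)
    update = softmax(K @ q) @ V                        # (d + C,)
    return softmax(update[-C:])                        # y-entry -> y_hat_query


def metrics(p, label):
    """Per-context metrics: CE loss, probability of the correct class, entropy."""
    return -np.log(p[label]), p[label], -np.sum(p * np.log(p))


# Hand-set (not trained) weights: scores x_i . x_q / tau, values eta0 * y_i.
C, n, d, tau, eta0 = 5, 100, 5, 0.1, 10.0
W_K = np.vstack([np.eye(d), np.zeros((C, d))])         # (d + C, d), reads x only
W_Q = W_K / tau
W_V = np.zeros((d + C, d + C))
W_V[d:, d:] = eta0 * np.eye(C)                         # copy labels into the y-entry

rng = np.random.default_rng(0)
X, Y, xq, label = sample_context(rng, C, n, d)
Z = build_tokens(X, Y, xq)
p_hat = softmax_sa_predict(Z, W_Q, W_K, W_V, C)
loss, p_correct, entropy = metrics(p_hat, label)
```

In a trained model, $W_Q$, $W_K$ and $W_V$ would be learned; comparing `metrics` for the trained model and for the corresponding GD or kernel GD step, context by context, is the kind of per-sample comparison shown in panels b) to d) of Figures 4 and 5.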

Theorems & Definitions (2)

  • Proposition 3.1
  • Proposition 3.2