One-Layer Transformer Provably Learns One-Nearest Neighbor In Context

Zihao Li, Yuan Cao, Cheng Gao, Yihan He, Han Liu, Jason M. Klusowski, Jianqing Fan, Mengdi Wang

TL;DR

This paper studies the capability of one-layer transformers in learning one of the most classical nonparametric estimators, the one-nearest neighbor prediction rule, and shows that, although the training loss is nonconvex, a single softmax attention layer trained with gradient descent can successfully learn to behave like a one-nearest neighbor classifier.

Abstract

Transformers have achieved great success in recent years. Interestingly, transformers have shown particularly strong in-context learning capability: even without fine-tuning, they are still able to solve unseen tasks well purely based on task-specific prompts. In this paper, we study the capability of one-layer transformers in learning one of the most classical nonparametric estimators, the one-nearest neighbor prediction rule. Under a theoretical framework where the prompt contains a sequence of labeled training data and unlabeled test data, we show that, although the training loss is nonconvex, a single softmax attention layer trained with gradient descent can successfully learn to behave like a one-nearest neighbor classifier. Our result gives a concrete example of how transformers can be trained to implement nonparametric machine learning algorithms, and sheds light on the role of softmax attention in transformer models.
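
To make the connection between softmax attention and one-nearest neighbor concrete, the following is a minimal NumPy sketch, not the paper's model. It assumes the context points lie on the unit sphere, that the attention score is a scaled inner product `beta * <x_query, x_i>`, and that labels are independent ±1 values; the names `softmax_attention_predict`, `one_nn_predict`, and the scale `beta` are hypothetical and stand in for the paper's learned weight matrix $\mathbf{W}$. On the unit sphere, $\|\mathbf{x} - \mathbf{x}_i\|^2 = 2 - 2\langle \mathbf{x}, \mathbf{x}_i\rangle$, so the largest inner product identifies the nearest neighbor, and a large scale makes the softmax weights concentrate on it.

```python
import numpy as np

def sample_sphere(n, d, rng):
    """Draw n points uniformly from the unit sphere in R^d."""
    x = rng.standard_normal((n, d))
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def softmax_attention_predict(X, y, x_query, beta):
    """Softmax-weighted average of context labels.

    Scores are beta * <x_query, x_i>; beta is a hypothetical scalar scale
    (an assumption of this sketch, not the paper's parametrization).
    """
    scores = beta * (X @ x_query)
    scores = scores - scores.max()   # subtract the max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()
    return float(weights @ y)

def one_nn_predict(X, y, x_query):
    """Return the label of the context point closest to x_query."""
    dists = np.linalg.norm(X - x_query, axis=1)
    return float(y[np.argmin(dists)])

rng = np.random.default_rng(0)
N, d = 50, 3
X = sample_sphere(N, d, rng)             # context inputs (assumed uniform)
y = rng.choice([-1.0, 1.0], size=N)      # context labels (assumed independent)
x_query = sample_sphere(1, d, rng)[0]    # unlabeled test input

for beta in (1.0, 10.0, 100.0, 1e4):
    pred = softmax_attention_predict(X, y, x_query, beta)
    print(f"beta={beta:g}: attention prediction = {pred:+.4f}")
print("1-NN prediction:", one_nn_predict(X, y, x_query))
```

With `beta = 0` the attention output is the plain average of the labels; as `beta` grows, the output moves toward the label of the nearest context point.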

Paper Structure

This paper contains 32 sections, 23 theorems, 130 equations, 3 figures.

Key Result

Theorem 1

Consider performing gradient descent of the softmax-attention transformer model $\widehat{\mathbf{y}}_{\mathbf{W}}(\mathbf{x}_{N+1})$. Suppose the initialization satisfies Assumption ass:init with $\sigma > 2(\max\{\log(Nd), - \log(1 -(N\sqrt{d})^{\frac{1}{d}}),C_d(1 - \frac{1}{2^N})\})$, where $C_

Figures (3)

  • Figure 1: Illustration of the data distribution in Assumption \ref{ass: data-dist} on $\mathbb{S}^2$ and the corresponding ground-truth division of $\mathbb{S}^2$ generated by one-nearest neighbor. (1) In the left panel, the red and blue points correspond to the $\mathbf{x}_i$ with $\mathbf{y}_i =1$ and $-1$ for $i \in[N]$, respectively, with $N = 500$. (2) In the right panel, the color of every point on the sphere is the same as that of its closest neighbor in $\{\mathbf{x}_{i}\}_{i\in[N]}$. The sphere is thus split into regions by the one-nearest-neighbor decision rule.
  • Figure 2: Heatmap and landscape of the loss function of a single-layer transformer when learning from one-nearest neighbor. The loss, defined in Eq. \ref{eq:loss-func}, is computed from 100 training sequences sampled according to Assumption \ref{ass: data-dist}, with $d = N = 4$. We parametrize $\mathbf{W}$ as ${\rm diag}\{\xi_1,\ldots, \xi_1, 0, \xi_2\}$.
  • Figure 3: Prediction error for a single softmax attention layer as a function of gradient iteration number. (1) The left panel shows the convergence of the loss function during the training process. (2) The right panel shows the MSE between the trained model and a 1-NN predictor on a well-separated testing dataset under distribution shift, as we discuss in Section \ref{sec:num-result}. Curves and error bars in both panels are computed from 10 independent trials, with error bars showing twice the standard deviation.
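
The partition described in the Figure 1 caption can be sketched numerically. The code below is an illustration under stated assumptions, not the authors' plotting code: the exact data distribution of the referenced assumption is not reproduced on this page, so uniform points on $\mathbb{S}^2$ with independent ±1 labels are assumed here, and the names `sample_sphere`, `Q`, and `region_label` are hypothetical. It assigns each query point on the sphere the label of its nearest context point.

```python
import numpy as np

def sample_sphere(n, rng):
    """Draw n points uniformly from the unit sphere S^2 in R^3."""
    v = rng.standard_normal((n, 3))
    return v / np.linalg.norm(v, axis=1, keepdims=True)

rng = np.random.default_rng(0)

N = 500                                  # number of labeled points, as in the caption
X = sample_sphere(N, rng)                # assumed: uniform context points on S^2
y = rng.choice([-1, 1], size=N)          # assumed: independent +/-1 labels

Q = sample_sphere(5000, rng)             # query points covering the sphere

# On the unit sphere, ||q - x||^2 = 2 - 2<q, x>, so the nearest neighbor
# of q is the context point with the largest inner product.
nearest = np.argmax(Q @ X.T, axis=1)
region_label = y[nearest]                # 1-NN label for every query point

print("fraction of query points labeled +1:", np.mean(region_label == 1))
```

Coloring the points of `Q` by `region_label` in a 3D scatter plot would give a picture in the spirit of the right panel.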

Theorems & Definitions (38)

  • Definition 1: One-Nearest Neighbor Predictor
  • Theorem 1: Convergence of Gradient Descent
  • Theorem 2: Resemblance to 1-NN predictor under Distribution Shift
  • Corollary 1: Classification of Trained Transformer
  • Lemma 1: Closed-Form Gradient
  • Lemma 2: Two-Dimensional System
  • Lemma 3: Nonconvexity of Transformer Optimization
  • Lemma 4
  • Lemma 5
  • Lemma 6
  • ...and 28 more