Transformers Trained via Gradient Descent Can Provably Learn a Class of Teacher Models

Chenyang Zhang, Qingyue Zhao, Quanquan Gu, Yuan Cao

Abstract

Transformers have achieved great success across a wide range of applications, yet the theoretical foundations underlying their success remain largely unexplored. To demystify the strong capacities of transformers applied to versatile scenarios and tasks, we theoretically investigate utilizing transformers as students to learn from a class of teacher models. Specifically, the teacher models covered in our analysis include convolution layers with average pooling, graph convolution layers, and various classic statistical learning models, including a variant of sparse token selection models [Sanford et al., 2023; Wang et al., 2024] and group-sparse linear predictors [Zhang et al., 2025]. When learning from this class of teacher models, we prove that one-layer transformers with simplified "position-only" attention can successfully recover all parameter blocks of the teacher models, thus achieving the optimal population loss. Building upon the efficient mimicry of trained transformers towards teacher models, we further demonstrate that they can generalize well to a broad class of out-of-distribution data under mild assumptions. The key in our analysis is to identify a fundamental bilinear structure shared by various learning tasks, which enables us to establish unified learning guarantees for these tasks when treating them as teachers for transformers.
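As an illustration of the bilinear structure mentioned in the abstract, the following NumPy sketch compares a student with position-only attention against an average-pooling teacher. It is not the paper's exact parameterization: the function names `position_only_student` and `avg_pool_teacher`, the logit matrix `W_P`, the column-wise softmax, and the `groups` argument are all assumptions made for this example. Both maps have the form (value matrix) times (tokens) times (position-mixing matrix), so the student matches the teacher once its value matrix equals $\mathbf{V}^*$ and its attention scores approach the pooling weights.

```python
import numpy as np


def softmax_columns(Z):
    """Numerically stable softmax applied to each column of Z."""
    Z = Z - Z.max(axis=0, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)


def position_only_student(X, W_V, W_P):
    """Hypothetical one-layer student (an assumption, not the paper's definition).

    X:   (d, D) input tokens stored as columns.
    W_V: (m, d) value matrix.
    W_P: (D, D) attention logits that depend on positions only, never on X.
    """
    S = softmax_columns(W_P)  # attention scores, independent of token content
    return W_V @ X @ S        # bilinear in (W_V, S)


def avg_pool_teacher(X, V_star, groups):
    """Hypothetical teacher: linear map V_star, then average pooling per group.

    groups must partition the positions 0..D-1.
    """
    D = X.shape[1]
    A = np.zeros((D, D))
    for g in groups:
        A[np.ix_(g, g)] = 1.0 / len(g)  # uniform weights inside each group
    return V_star @ X @ A


# Small demonstration with made-up sizes.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))
V_star = rng.normal(size=(3, 4))
groups = [[0, 1, 2], [3, 4, 5]]

# Logits that strongly favor positions in the same group.
W_P = np.full((6, 6), -25.0)
for g in groups:
    W_P[np.ix_(g, g)] = 25.0

gap = np.abs(position_only_student(X, V_star, W_P) - avg_pool_teacher(X, V_star, groups)).max()
print("max abs difference:", gap)
```

In this toy setting the student reproduces the teacher because the softmax of the position logits is nearly uniform within each group and nearly zero outside it.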

Paper Structure

This paper contains 21 sections, 33 theorems, 226 equations, 9 figures, 1 table.

Key Result

Theorem 3.1

Suppose that $D\geq \Omega(\mathrm{poly}(M, K))$ and $\eta\leq\mathcal{O}(M^{-1}D^{-5/2})$. Then there exists $T^* = \Theta\left(\frac{KD^2}{\eta \|\mathbf{V}^*\|_F^2}\right)$ such that for all $T\geq T^*$, the following results hold.
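To make the scaling of $T^*$ concrete, here is a small Python sketch that evaluates $KD^2/(\eta\|\mathbf{V}^*\|_F^2)$. The theorem only fixes $T^*$ up to a constant factor, so the function name `convergence_time_scale` and the constant `c` (set to 1) are assumptions for illustration, and the returned value is an order of magnitude rather than an exact iteration count.

```python
import numpy as np


def convergence_time_scale(K, D, eta, V_star, c=1.0):
    """Evaluate c * K * D^2 / (eta * ||V_star||_F^2).

    K, D:   the quantities appearing in the theorem statement.
    eta:    learning rate.
    V_star: teacher value matrix as a 2-D array.
    c:      hypothetical stand-in for the constant hidden by Theta(.).
    """
    fro_sq = np.linalg.norm(V_star, "fro") ** 2
    return c * K * D**2 / (eta * fro_sq)


# Made-up values, chosen only to show how the expression scales.
V_star = np.eye(2)
print(convergence_time_scale(K=2, D=10, eta=0.01, V_star=V_star))
print(convergence_time_scale(K=2, D=10, eta=0.005, V_star=V_star))
```

Halving the step size doubles the estimate, and a teacher with a larger Frobenius norm shortens it.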

Figures (9)

  • Figure 1: Visualization of parameter matrices for the transformer $\widetilde{\text{TF}}$ in \ref{eq:entire_tf}, obtained after training to learn the teacher model $f^*$ until the loss converges. The loss function and training algorithm are formally described in the next section.
  • Figure 2: Excess training loss, excess OOD test loss (both in log-log scale), and cosine similarity between the value matrix $\mathbf{W}_V$ of the one-layer transformer \ref{def:position_only_tf} and the ground truth value matrix $\mathbf{V}^*$. These results are presented for six experimental sets, which originate from four distinct tasks.
  • Figure 3: Heatmap of the attention score matrix $\mathbf{S}^{(T)}$ when the training loss converges. The results are presented for six different experimental sets, as indicated by the sub-figure captions.
  • Figure 4: Training loss and cosine similarity between the value matrix $\mathbf{W}_V$ of the one-layer transformer \ref{def:position_only_tf} and the convolution kernel matrix $\mathbf{V}^*$ of the pre-trained teacher CNN.
  • Figure 5: Heatmap of the ground truth softmax scores of average pooling, heatmap of the attention scores $\mathbf{S}^{(T)}$ of the trained one-layer transformer when the loss converges, and an example image from MNIST.
  • ...and 4 more figures

Theorems & Definitions (39)

  • Example 2.1: Single convolutional layer with average pooling
  • Example 2.2: Single graph convolution layer on a regular graph
  • Example 2.3: Sparse token selection model [Sanford et al., 2023; Wang et al., 2024]
  • Remark 2.4
  • Example 2.5: Group-sparse linear predictors [Zhang et al., 2025]
  • Theorem 3.1
  • Theorem 3.2
  • Lemma D.1
  • Lemma D.2
  • Lemma D.3
  • ...and 29 more