
Transformer Learns Optimal Variable Selection in Group-Sparse Classification

Chenyang Zhang, Xuran Meng, Yuan Cao

TL;DR

This work analyzes how a one-layer self-attention transformer, trained by gradient descent on population cross-entropy, can learn a classical group-sparse classification model where the label depends on variables from a single group. It provides a tight global convergence bound showing that, after $T^* = \Theta(D^3 \vee 1/(D^3 \epsilon^3))$ iterations, the attention concentrates on the label-relevant group with $\mathbf{S}_{j^*, j}^{(T^*)} \ge 1-\exp(-\Theta(D))$ and the value vector aligns with the ground-truth $v^*$, with $\mathbf{v}_2^{(T^*)}=0$. The paper also establishes transferability to downstream tasks sharing the same sparsity pattern, giving a generalization bound $\frac{1}{n}\sum_{i=1}^n \mathbb{P}(y^{(i)} f(\mathbf{Z}^{(i)}, \widetilde{W}^{(i)}, \widetilde{v}^{(i)}) \le 0) \le O\left(\frac{d+D}{\gamma^2 n}\log^2 n\right) + O\left(\frac{\log(1/\delta)}{n}\right)$ and a sample complexity of $\widetilde{\Omega}((d+D)/\epsilon + (1/\epsilon)\log(1/\delta))$, outperforming vectorized logistic regression in certain regimes. Empirical results on synthetic data and CIFAR-10 patches corroborate the theory, showing convergent training, interpretable attention focusing on the correct group or patch, and robust downstream performance. This work thus bridges theoretical understanding of one-layer transformer mechanisms with practical, structure-exploiting learning for grouped features.

Abstract

Transformers have demonstrated remarkable success across various applications. However, the success of transformers has not been well understood in theory. In this work, we give a case study of how transformers can be trained to learn a classic statistical model with "group sparsity", where the input variables form multiple groups and the label depends only on the variables from one of the groups. We theoretically demonstrate that a one-layer transformer trained by gradient descent can correctly leverage the attention mechanism to select variables, disregarding irrelevant ones and focusing on those beneficial for classification. We also demonstrate that a well-pretrained one-layer transformer can be adapted to new downstream tasks to achieve good prediction accuracy with a limited number of samples. Our study sheds light on how transformers effectively learn structured data.
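The group-sparse setting described above can be illustrated with a small NumPy sketch. It is not the authors' code: the Gaussian sampling with scale `sigma_x`, the sign-based label rule, the position-only attention logits, and the helper names `sample_group_sparse` and `attention_predict` are all assumptions chosen for illustration. It shows only that attention concentrated on the relevant group $j^*$ is enough to recover the labels.

```python
import numpy as np

def sample_group_sparse(n, D, d, j_star, v_star, sigma_x=1/3, seed=0):
    """Draw n inputs made of D groups with d Gaussian variables each.
    Assumed label rule: y = sign(<v_star, x_{j_star}>), so only group j_star matters."""
    rng = np.random.default_rng(seed)
    X = sigma_x * rng.standard_normal((n, D, d))  # X[i, j] is group j of sample i
    y = np.sign(X[:, j_star, :] @ v_star)
    y[y == 0] = 1.0  # break (measure-zero) ties
    return X, y

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def attention_predict(X, scores, v):
    """Hypothetical simplified one-layer attention.
    scores[q, k] is a position-only logit for query group q and key group k;
    each query pools the groups with softmax weights, and the pooled vectors
    are summed over queries and projected onto the value vector v."""
    S = softmax(scores, axis=-1)             # each row sums to 1
    pooled = np.einsum("qk,nkd->nqd", S, X)  # shape (n, D, d)
    return pooled.sum(axis=1) @ v, S

# Usage: hand-built logits that put almost all attention on group j_star.
n, D, d, j_star = 500, 6, 4, 2
v_star = np.ones(d) / np.sqrt(d)
X, y = sample_group_sparse(n, D, d, j_star, v_star)

scores = np.zeros((D, D))
scores[:, j_star] = 20.0                     # every query attends to group j_star
out, S = attention_predict(X, scores, v_star)
accuracy = np.mean(np.sign(out) == y)
print("min attention on j*:", S[:, j_star].min(), "accuracy:", accuracy)
```

Here the attention logits are set by hand. In the paper's analysis, the corresponding concentration is what gradient descent is shown to learn.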

Paper Structure

This paper contains 22 sections, 39 theorems, 206 equations, 6 figures.

Key Result

Theorem 3.2

For any $\epsilon > 0$, suppose that $D\geq \omega(\log^2 (1/\epsilon))$, $d\leq O(\mathrm{poly}(D))$, $\sigma_x, \eta=\Theta(1)$ with $\sigma_x\leq 1/3$ and let $T^* = \Theta(D^3\vee\frac{1}{D^3\epsilon^3})$. Under these conditions, it holds that

Figures (6)

  • Figure 1: Training loss, cosine similarity, and norm ratio. The first row presents the training results with a sample size of 400, 6 variable groups, and a variable dimension of 4. The second row shows the training results for a sample size of 200, with 4 variable groups and a variable dimension of 2.
  • Figure 2: Heatmaps of the average attention matrix. One panel shows the attention matrix for the setting with 6 variable groups, and the other shows the attention matrix for the setting with 4 variable groups.
  • Figure 3: Test accuracy on the downstream task for different numbers of variable groups and different variable dimensions.
  • Figure 4: Training loss, cosine similarity, and attention matrix. The first row presents the training results with $j^*=30$. The second row shows the training results for $j^*=70$.
  • Figure 5: Examples of embedded images. Two panels show images labeled “Frog” and “Airplane,” respectively, embedded at position (1,1) with a token index of $1$. The other two panels show images labeled “Frog” and “Airplane,” respectively, embedded at position (4,4) with a token index of $25$.
  • ...and 1 more figure

Theorems & Definitions (41)

  • Definition 3.1: Group sparse inputs following Gaussian distribution
  • Theorem 3.2
  • Definition 4.1
  • Theorem 4.2
  • Lemma 5.1
  • Lemma 5.2
  • Lemma 5.3
  • Proposition B.1
  • Lemma B.2
  • Lemma B.3
  • ...and 31 more