A Theoretical Understanding of Shallow Vision Transformers: Learning, Generalization, and Sample Complexity

Hongkang Li, Meng Wang, Sijia Liu, Pin-Yu Chen

TL;DR

Proper token sparsification can improve test performance by removing label-irrelevant and/or noisy tokens, including spurious correlations, and training with stochastic gradient descent (SGD) provably leads to a sparse attention map.

Abstract

Vision Transformers (ViTs) with self-attention modules have recently achieved great empirical success in many vision tasks. Due to non-convex interactions across layers, however, a theoretical analysis of their learning and generalization remains largely elusive. Based on a data model characterizing both label-relevant and label-irrelevant tokens, this paper provides the first theoretical analysis of training a shallow ViT, i.e., one self-attention layer followed by a two-layer perceptron, for a classification task. We characterize the sample complexity required to achieve zero generalization error. Our sample complexity bound is positively correlated with the inverse of the fraction of label-relevant tokens, the token noise level, and the initial model error. We also prove that training with stochastic gradient descent (SGD) leads to a sparse attention map, which formally verifies the general intuition behind the success of attention. Moreover, this paper indicates that proper token sparsification can improve test performance by removing label-irrelevant and/or noisy tokens, including spurious correlations. Experiments on synthetic data and the CIFAR-10 dataset confirm our theoretical results and show that they extend to deeper ViTs.
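The shallow ViT named in the abstract, one self-attention layer followed by a two-layer perceptron, can be sketched as a forward pass. This is a minimal plain-Python sketch under our own assumptions: a single attention head, keys, queries, and values given by the hypothetical matrices `W_K`, `W_Q`, `W_V`, a hidden layer `W_O` with ReLU, output weights `a`, and per-query outputs averaged into one scalar. The function `shallow_vit` and all parameter names are illustrative, and the paper's exact parameterization may differ.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Multiply matrix W (rows x cols) by a vector x of length cols."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(p * q for p, q in zip(u, v))


def softmax(scores: Vector) -> Vector:
    """Softmax; the max is subtracted first for numerical stability."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def shallow_vit(X: List[Vector], W_Q: Matrix, W_K: Matrix, W_V: Matrix,
                W_O: Matrix, a: Vector) -> float:
    """Scalar output of one self-attention layer plus a two-layer perceptron.

    For each query token x_l, attention weights softmax_s(<W_K x_s, W_Q x_l>)
    mix the value vectors W_V x_s. The mixture passes through W_O, a ReLU,
    and the output weights a. Averaging over queries is an assumption here.
    """
    keys = [matvec(W_K, x) for x in X]
    values = [matvec(W_V, x) for x in X]
    out = 0.0
    for x_l in X:
        q = matvec(W_Q, x_l)
        attn = softmax([dot(k, q) for k in keys])
        # Attention-weighted sum of value vectors, coordinate by coordinate.
        mixed = [sum(w * v[j] for w, v in zip(attn, values))
                 for j in range(len(values[0]))]
        hidden = [max(0.0, h) for h in matvec(W_O, mixed)]  # ReLU
        out += dot(a, hidden)
    return out / len(X)


if __name__ == "__main__":
    I2 = [[1.0, 0.0], [0.0, 1.0]]             # identity projections for a toy run
    X = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]  # three 2-d tokens
    print(shallow_vit(X, I2, I2, I2, I2, a=[1.0, -1.0]))
```

Because attention and the final average are both symmetric in the tokens, reordering the tokens in `X` leaves the output unchanged. The sign of the output plays the role of the predicted class label in this toy setting.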

Paper Structure (3 theorems, 3 equations, 4 figures)

Key Result

Theorem 1

Given a sufficiently large model, $\alpha_*$ and $\alpha_\#$ satisfying the theorem's condition for some $c\in(0,1/(2e))$, and sufficiently large mini-batches and sets of sampled tokens for each data point, zero generalization error is achieved with a sample complexity $N$ and a number of iterations $T$:

Figures (4)

  • Figure 1: The impact of $\alpha_*$ on the sample complexity for (a) ViT and (b) CNN.
  • Figure 2: (a) Concentration of attention weights (b) Impact of token sparsification on testing loss.
  • Figure: Transformer-based foundation models
  • Figure: Vision Transformer (DBKW21)
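Figure 2 refers to two quantities that a small example can make concrete: how concentrated the attention weights are, and what token sparsification does. The sketch below is illustrative only and rests on our own assumptions. It measures concentration as the total attention a query places on a given set of label-relevant token indices, and it models sparsification as dropping chosen token indices before attention is computed. The helpers `attention_weights`, `relevant_mass`, and `sparsify`, the `scale` parameter, and the identity key/query setup are hypothetical. The paper's actual measurement and sparsification procedure may differ.

```python
import math
from typing import List, Sequence, Set

Vector = List[float]


def softmax(scores: Sequence[float]) -> List[float]:
    """Softmax; the max is subtracted first for numerical stability."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(tokens: Sequence[Vector], query: Vector,
                      scale: float = 1.0) -> List[float]:
    """Softmax attention of one query over the tokens (keys = tokens, assumed)."""
    return softmax([scale * sum(t_i * q_i for t_i, q_i in zip(t, query))
                    for t in tokens])


def relevant_mass(weights: Sequence[float], relevant: Set[int]) -> float:
    """Total attention placed on the label-relevant token indices."""
    return sum(w for i, w in enumerate(weights) if i in relevant)


def sparsify(tokens: Sequence[Vector], drop: Set[int]) -> List[Vector]:
    """Token sparsification: remove the tokens whose indices are in `drop`."""
    return [t for i, t in enumerate(tokens) if i not in drop]


if __name__ == "__main__":
    # Tokens 0 and 1 are "label-relevant"; tokens 2 and 3 are "irrelevant".
    tokens = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
    query = [1.0, 0.0]
    for scale in (1.0, 5.0):
        w = attention_weights(tokens, query, scale)
        print(scale, relevant_mass(w, {0, 1}))
    kept = sparsify(tokens, drop={2, 3})
    print(relevant_mass(attention_weights(kept, query), {0, 1}))
```

Raising `scale` stands in loosely for query and key weights growing during training. It only illustrates what a concentrated (sparse) attention map looks like and does not reproduce the paper's SGD dynamics. Dropping the irrelevant tokens puts all of the attention mass on the relevant ones.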

Theorems & Definitions (3)

  • Theorem 1
  • Proposition 1
  • Proposition 2