Powerful Design of Small Vision Transformer on CIFAR10

Gent Wu

TL;DR

This work targets efficient Tiny Vision Transformers (ViTs) for CIFAR-10. It addresses their performance gap on small datasets by investigating data augmentation, patch token initialization, low-rank attention via Multi-Head Latent Attention (MLA), and multi-class token strategies. It shows that low-rank compression of the queries incurs minimal accuracy loss, and that increasing CLS-token capacity with multiple class tokens improves global representation and accuracy. The paper provides a practical design framework, including ablations on augmentation, initialization, and optimizers, and reports that careful choices (e.g., learnable positional embeddings, the Lion optimizer) yield competitive results at reduced computational cost. The findings offer actionable guidance for building scalable, efficient Tiny ViTs on small datasets; code is available in the repository linked in the abstract.
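
The TL;DR names the Lion optimizer as one of the choices that helped. For readers unfamiliar with it, here is a minimal sketch of the Lion update rule as described in the paper that introduced it (Chen et al., 2023): every parameter moves by the sign of an interpolation between the momentum and the current gradient, plus decoupled weight decay. It operates on plain Python lists for clarity. `lion_step` is a hypothetical helper, not a library API. beta1 = 0.9 and beta2 = 0.99 follow the Lion paper, while the learning rate and weight decay shown are placeholders rather than the settings used in this work.

```python
from typing import List


def sign(x: float) -> float:
    """Return 1.0, -1.0, or 0.0 depending on the sign of x."""
    return float((x > 0) - (x < 0))


def lion_step(
    params: List[float],
    grads: List[float],
    momentum: List[float],
    lr: float = 1e-4,
    beta1: float = 0.9,
    beta2: float = 0.99,
    weight_decay: float = 0.0,
) -> None:
    """Apply one Lion update in place (hypothetical helper, not a library API)."""
    for i, (p, g, m) in enumerate(zip(params, grads, momentum)):
        # Step direction: sign of an interpolation between momentum and gradient.
        direction = sign(beta1 * m + (1.0 - beta1) * g)
        # Each coordinate moves by exactly lr, plus decoupled weight decay.
        params[i] = p - lr * (direction + weight_decay * p)
        # The stored momentum uses a second, slower coefficient beta2.
        momentum[i] = beta2 * m + (1.0 - beta2) * g


if __name__ == "__main__":
    params = [1.0, -2.0, 0.5]
    momentum = [0.0, 0.0, 0.0]
    lion_step(params, grads=[0.3, -4.0, 0.0], momentum=momentum, lr=0.1)
    print(params)  # [0.9, -1.9, 0.5]
```

Because only the sign of the interpolated direction is used, every coordinate moves by the same magnitude `lr` per step (before weight decay). This is why Lion is typically run with a smaller learning rate than AdamW.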

Abstract

Vision Transformers (ViTs) have demonstrated remarkable success on large-scale datasets, but their performance on smaller datasets often falls short of convolutional neural networks (CNNs). This paper explores the design and optimization of Tiny ViTs for small datasets, using CIFAR-10 as a benchmark. We systematically evaluate the impact of data augmentation, patch token initialization, low-rank compression, and multi-class token strategies on model performance. Our experiments reveal that low-rank compression of queries in Multi-Head Latent Attention (MLA) incurs minimal performance loss, indicating redundancy in ViTs. Additionally, introducing multiple CLS tokens improves global representation capacity, boosting accuracy. These findings provide a comprehensive framework for optimizing Tiny ViTs, offering practical insights for efficient and effective designs. Code is available at https://github.com/erow/PoorViTs.
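
The abstract's central efficiency claim is that the query projection can be made low-rank. The sketch below shows what that means in a single attention head. Instead of one full `dim x dim` query matrix, each token is first projected down to `rank` dimensions and then back up to `dim`, as in the query path of MLA (liu2024deepseek). This is a simplified, hypothetical illustration in pure Python, not the repository's implementation. Keys and values keep full-rank projections here, and MLA's key/value compression and decoupled rotary embeddings are omitted.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    """Gaussian init scaled by 1/sqrt(fan_in)."""
    std = 1.0 / math.sqrt(rows)
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


class LowRankQueryAttention:
    """Single-head attention with a rank-limited query projection (hypothetical sketch).

    Queries are computed as (x @ w_dq) @ w_uq: w_dq compresses each token to
    `rank` dimensions and w_uq expands it back to `dim`, so the effective
    query matrix w_dq @ w_uq has rank at most `rank`. Keys and values keep
    ordinary full-rank projections.
    """

    def __init__(self, dim: int, rank: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.dim, self.rank = dim, rank
        self.w_dq = random_matrix(dim, rank, rng)  # down-projection: dim -> rank
        self.w_uq = random_matrix(rank, dim, rng)  # up-projection: rank -> dim
        self.w_k = random_matrix(dim, dim, rng)
        self.w_v = random_matrix(dim, dim, rng)

    def query_params(self) -> int:
        """Parameters in the factored query path (vs. dim * dim for full rank)."""
        return 2 * self.dim * self.rank

    def __call__(self, x: Matrix) -> Matrix:
        q = matmul(matmul(x, self.w_dq), self.w_uq)  # passes through the rank bottleneck
        k = matmul(x, self.w_k)
        v = matmul(x, self.w_v)
        scale = 1.0 / math.sqrt(self.dim)
        scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
        weights = [softmax(row) for row in scores]  # each row sums to 1
        return matmul(weights, v)


if __name__ == "__main__":
    dim, rank, tokens = 16, 4, 5
    attn = LowRankQueryAttention(dim, rank)
    rng = random.Random(1)
    x = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(tokens)]
    out = attn(x)
    print(len(out), len(out[0]))           # 5 16
    print(attn.query_params(), dim * dim)  # 128 256
```

With `dim = 16` and `rank = 4`, the factored query projection has 2 · 16 · 4 = 128 parameters instead of 16² = 256. In general the factorization saves parameters whenever `rank < dim / 2`.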

Paper Structure (20 sections, 6 equations, 6 figures, 3 tables)

Figures (6)

  • Figure 1: Vision Transformer overview, from dosovitskiy2020image. It involves three stages. (1) Tokenization: the image is divided into patches of a predetermined size; each patch is linearly embedded, and position embeddings are added to preserve spatial information, yielding a sequence of vectors that represents the embedded patches together with their positions. (2) Token transformation: the tokens are fed into a standard Transformer encoder. (3) Task projection: for classification, the common approach is to prepend an additional learnable "classification token" to the sequence. This token is trained to aggregate information from all patches and serves as the basis for the classification decision.
  • Figure 2: Overview of Attention mechanism.
  • Figure 3: Illustration of Multi-Head Attention (MHA), Grouped-Query Attention (GQA), Multi-Query Attention (MQA), and Multi-Head Latent Attention (MLA), from liu2024deepseek.
  • Figure 4: Tracing one training step. One step (120.7 ms) comprises the forward pass (25.2 ms), the backward pass (79.3 ms), the optimizer update (12.7 ms), and other processes (the remaining 3.5 ms).
  • Figure 5: Whitening patterns in patch embedding.
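
Figure 1's tokenization stage, combined with the multiple CLS tokens discussed in the abstract, can be sketched as follows. The example assumes a CIFAR-10-sized 3 × 32 × 32 image, 4 × 4 patches, and an 8-dimensional embedding purely for illustration. All helper names are hypothetical. Averaging the CLS outputs in `pool_cls` is one plausible pooling choice, not necessarily the one used in the paper. With a single CLS token, the layout matches the standard ViT.

```python
import random
from typing import List

Matrix = List[List[float]]
Image = List[Matrix]  # (channels, height, width)


def patchify(image: Image, patch: int) -> Matrix:
    """Split a (C, H, W) image into flattened, non-overlapping patches in row-major order."""
    channels, height, width = len(image), len(image[0]), len(image[0][0])
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            patches.append([
                image[c][top + i][left + j]
                for c in range(channels)
                for i in range(patch)
                for j in range(patch)
            ])
    return patches


def linear(rows: Matrix, weight: Matrix) -> Matrix:
    """Apply a bias-free linear map: (n x d_in) @ (d_in x d_out)."""
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*weight)] for row in rows]


def tokenize(image: Image, patch: int, weight: Matrix,
             pos_embed: Matrix, cls_tokens: Matrix) -> Matrix:
    """Embed patches, add position embeddings, then prepend the CLS token(s).

    One CLS token gives the standard ViT layout; k > 1 gives the multi-CLS variant.
    """
    embedded = linear(patchify(image, patch), weight)
    with_pos = [[e + p for e, p in zip(tok, pos)] for tok, pos in zip(embedded, pos_embed)]
    return [list(cls) for cls in cls_tokens] + with_pos


def pool_cls(tokens: Matrix, num_cls: int) -> List[float]:
    """Average the first num_cls tokens into one feature vector (assumed pooling rule)."""
    return [sum(col) / num_cls for col in zip(*tokens[:num_cls])]


if __name__ == "__main__":
    rng = random.Random(0)
    channels, size, patch, dim, num_cls = 3, 32, 4, 8, 4
    image = [[[rng.random() for _ in range(size)] for _ in range(size)] for _ in range(channels)]
    patch_dim = channels * patch * patch  # 48 values per patch
    num_patches = (size // patch) ** 2    # 64 patches

    def init(rows: int) -> Matrix:
        return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(rows)]

    tokens = tokenize(image, patch, init(patch_dim), init(num_patches), init(num_cls))
    print(len(tokens), len(tokens[0]))  # 68 8
    # A real model runs the Transformer encoder on `tokens` before pooling.
    print(len(pool_cls(tokens, num_cls)))  # 8
```

Position embeddings are added only to patch tokens here. A learnable CLS token plus its own learnable position embedding sums to just another learnable vector, so this simplification does not reduce expressiveness.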