
Pruning One More Token is Enough: Leveraging Latency-Workload Non-Linearities for Vision Transformers on the Edge

Nick John Eliopoulos, Purvish Jajal, James C. Davis, Gaowen Liu, George K. Thiruvathukal, Yung-Hsiang Lu

TL;DR

This paper investigates how to efficiently deploy vision transformers on edge devices for small workloads. It identifies factors that affect ViT latency-workload relationships and determines a token pruning schedule that leverages the non-linearity of these relationships.

Abstract

This paper investigates how to efficiently deploy vision transformers on edge devices for small workloads. Recent methods reduce the latency of transformer neural networks by removing or merging tokens, with small accuracy degradation. However, these methods are not designed with edge device deployment in mind: they do not leverage information about latency-workload trends to improve efficiency. We address this shortcoming. First, we identify factors that affect ViT latency-workload relationships. Second, we determine a token pruning schedule by leveraging non-linear latency-workload relationships. Third, we demonstrate a training-free token pruning method that uses this schedule. We show that other methods may increase latency by 2-30%, while our method reduces latency by 9-26%. At similar latency (within 5.2% or 7 ms) across devices, we achieve 78.6%-84.5% ImageNet1K accuracy, while the state-of-the-art method, Token Merging, achieves 45.8%-85.4%.
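To make the title's idea concrete: on a step-shaped latency curve, removing just enough tokens to cross a step yields a large saving, while (in such a toy model) removing further tokens within the same step buys nothing. The sketch below searches a latency table for the smallest removal that meets a target gain. It is a hypothetical illustration, not the authors' scheduling algorithm: the function name `tokens_to_prune`, the `min_gain` threshold, and the synthetic staircase curve (2 ms per block of 64 tokens) are all assumptions. In practice the table would come from measurements on the target device.

```python
import math


def tokens_to_prune(latency_ms: dict[int, float], n_tokens: int,
                    min_gain: float = 0.05) -> int:
    """Hypothetical helper: smallest number of tokens to remove from n_tokens
    so that the tabulated latency falls by at least min_gain (a fraction of
    the current latency). Returns 0 if no entry in the table qualifies."""
    base = latency_ms[n_tokens]
    target = base * (1.0 - min_gain)
    # Walk downward one token at a time; on a step-shaped curve the first
    # hit is the edge of the next lower step.
    for k in range(1, n_tokens):
        latency = latency_ms.get(n_tokens - k)
        if latency is not None and latency <= target:
            return k
    return 0


# Synthetic staircase curve (assumed, not measured): 2 ms per block of
# 64 tokens, mimicking kernels that process tokens in fixed-size tiles.
toy = {n: 2.0 * math.ceil(n / 64) for n in range(1, 258)}

# 197 tokens = DeiT-B's sequence length at 224x224 (196 patches + CLS).
print(tokens_to_prune(toy, 197))  # → 5
```

On this toy curve, removing 5 tokens (197 to 192) cuts latency from 8 ms to 6 ms, whereas removing 1 to 4 tokens saves nothing.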

Paper Structure

This paper contains 19 sections, 2 equations, 5 figures, 10 tables, 3 algorithms.

Figures (5)

  • Figure 1: Forward-pass latency of the widely used DeiT-B ($d=768$) and DinoV2-G ($d=1536$) models across various hardware (listed in the hardware table), evaluated on the ImageNet1K classification dataset [deng_imagenet_2009]. These plots show that the relationship between workload size (as defined in the section on choosing the number of tokens to prune) and latency is variable and non-linear across hardware. Consequently, in many cases a large latency reduction can be achieved by removing only a few tokens. This work shows when and how to remove tokens to exploit these latency non-linearities.
  • Figure 2: Latency-workload characteristics of attention operators in PyTorch [torch_sdpa_2024]. Flash, MemEfficient, and Math are optimized implementations, while Vanilla is not. Median latency was measured over 100 runs for each token count; all measurements had an IQR $< 1\,\mu$s. The two annotated latency changes of MemEfficient are discussed in the section on choosing the number of tokens to prune.
  • Figure 3: Illustration of our method to decide a pruning schedule (left) and how we prune according to the schedule at inference time (right), which are discussed in the sections shown at the top of the illustration.
  • Figure 4: Accuracy-latency tradeoffs of the surveyed methods with $M$ = DinoV2-G: (a) batch size 2 on AGX Orin, (b) batch size 4 on A100, (c) batch size 4 on AGX Orin. Our pruning schedule and mechanism generate points that expand the Pareto front. The number of tokens removed at each layer ($r$) by Top-K and ToMe is evaluated from $r=5$ to $r=8$ in increments of 1.
  • Figure 5: The GPU tail effect has less impact at large batch sizes (here, DeiT-B on the AGX Orin with a batch size of 128).
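Figure 2's caption reports the median latency over 100 runs together with the interquartile range (IQR). A minimal timing harness in that spirit is sketched below in pure Python; the function name `median_latency`, the warm-up count, and the `sync` hook are illustrative assumptions, not the authors' benchmarking code. When timing GPU kernels, which launch asynchronously, the hook should block until the device is idle (in PyTorch, `torch.cuda.synchronize()`); otherwise the clock measures little more than launch overhead.

```python
import statistics
import time
from typing import Callable, Optional


def median_latency(
    fn: Callable[[], object],
    runs: int = 100,
    warmup: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> tuple[float, float]:
    """Time fn() `runs` times and return (median, IQR) in microseconds.

    `sync` is an optional hook called after the warm-up phase and after each
    timed call, before the clock is read, so queued asynchronous work is
    included in the measurement.
    """
    for _ in range(warmup):  # discard warm-up iterations (caching, lazy init)
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait until the work has actually completed
        samples.append((time.perf_counter() - start) * 1e6)
    q1, _, q3 = statistics.quantiles(samples, n=4)  # quartile cut points
    return statistics.median(samples), q3 - q1


# Usage: time a small CPU-bound workload.
med_us, iqr_us = median_latency(lambda: sum(range(10_000)))
print(f"median {med_us:.1f} us, IQR {iqr_us:.1f} us")
```

Reporting the median with the IQR, rather than the mean, keeps occasional outliers (scheduler hiccups, clock-frequency changes) from dominating the latency-workload curve.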