TReX- Reusing Vision Transformer's Attention for Efficient Xbar-based Computing
Abhishek Moitra, Abhiroop Bhattacharjee, Youngeun Kim, Priyadarshini Panda
TL;DR
This work addresses the large energy-delay-area cost of Vision Transformers (ViTs) deployed on edge devices with In-Memory Computing (IMC). It introduces TReX, an attention-reuse optimization framework that uses small transformation blocks to reuse attention outputs across selected encoder layers. A hardware-aware search, performed with the TReX-Sim platform, decides which layers reuse attention. The method stays near iso-accuracy while substantially reducing the energy-delay-area product (EDAP), by up to about 2.3x on ImageNet-1k for DeiT-S and LV-ViT-S, and it improves TOPS/W and TOPS/mm$^2$. Robustness is further validated with FeFET-SRAM hybrid crossbar designs and with NLP results on CoLA. In practice, the work enables energy-efficient, high-throughput ViTs for edge deployment and shows that the approach carries over to NLP tasks.
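The attention-reuse idea can be sketched in plain Python. This is a minimal single-head illustration under stated assumptions, not the TReX implementation. The helper names (`self_attention`, `transform_block`, `encoder_stack`) are hypothetical. The transformation block is assumed to be a per-channel scaling of the previous layer's attention output, and the MLP sub-block and LayerNorm are omitted. Layers listed in `reuse` skip the Q/K/V projections and the softmax. On an IMC accelerator, that is where crossbar matrix-vector multiplications would be saved.

```python
import math
from typing import Dict, List, Set, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def self_attention(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix) -> Matrix:
    """Single-head attention: softmax(Q K^T / sqrt(d)) V."""
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    scale = math.sqrt(len(q[0]))
    scores = [[s / scale for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


def transform_block(prev_attn: Matrix, channel_scale: List[float]) -> Matrix:
    """Hypothetical lightweight transformation: per-channel scaling of a reused output."""
    return [[v * s for v, s in zip(row, channel_scale)] for row in prev_attn]


def encoder_stack(x: Matrix, layers: List[Dict], reuse: Set[int]) -> Tuple[Matrix, List[str]]:
    """Run the attention sub-block of each encoder; layers in `reuse` recycle the
    previous layer's attention output instead of recomputing Q, K, V and softmax."""
    prev_attn = None
    trace = []
    for i, p in enumerate(layers):
        if i in reuse and prev_attn is not None:
            attn = transform_block(prev_attn, p["channel_scale"])
            trace.append("reuse")
        else:
            attn = self_attention(x, p["wq"], p["wk"], p["wv"])
            trace.append("compute")
        # Residual connection; the MLP sub-block and LayerNorm are omitted for brevity.
        x = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(x, attn)]
        prev_attn = attn
    return x, trace


# Toy usage: three identical 2-dim layers, layer 1 reuses layer 0's attention output.
identity = [[1.0, 0.0], [0.0, 1.0]]
layer = {"wq": identity, "wk": identity, "wv": identity, "channel_scale": [0.5, 0.5]}
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, trace = encoder_stack(tokens, [layer, layer, layer], reuse={1})
print(trace)  # ['compute', 'reuse', 'compute']
```

A reused layer still adds its transformed attention output through the residual path. Token shapes therefore stay unchanged, and the rest of the encoder does not need to know which layers reused attention.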
Abstract
Due to the high computation overhead of Vision Transformers (ViTs), In-memory Computing (IMC) architectures are being researched for energy-efficient deployment in edge-computing scenarios. Prior works have proposed efficient algorithm-hardware co-design and IMC-architectural improvements to increase the energy efficiency of IMC-implemented ViTs. However, all prior works have neglected the overhead and co-dependence of attention blocks on the accuracy-energy-delay-area of IMC-implemented ViTs. To this end, we propose TReX, an attention-reuse-driven ViT optimization framework that performs attention reuse in ViT models to achieve optimal accuracy-energy-delay-area tradeoffs. TReX optimally chooses the transformer encoders for attention reuse to achieve near iso-accuracy performance while meeting the user-specified delay requirement. On the ImageNet-1k dataset, we find that TReX achieves a 2.3x (2.19x) reduction in energy-delay-area product (EDAP) and a 1.86x (1.79x) TOPS/mm$^2$ improvement with ~1% accuracy drop for DeiT-S (LV-ViT-S). Additionally, TReX maintains higher accuracy at high EDAP reduction than state-of-the-art token pruning and weight sharing approaches. On NLP tasks such as CoLA, TReX yields 2% higher non-ideal accuracy than the baseline at 1.6x lower EDAP.
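The delay-constrained choice of which encoders reuse attention can be illustrated with a toy search. This sketch assumes a hypothetical per-encoder cost model, with energy, delay and area both with and without reuse, and an additive per-encoder accuracy penalty. In TReX the hardware estimates come from TReX-Sim. The exhaustive search below is a stand-in and not the paper's procedure. It enumerates reuse sets, skipping encoder 0 because no earlier attention exists. It keeps only the sets whose total delay meets the budget and picks the one with the lowest estimated accuracy drop, breaking ties by EDAP = energy × delay × area. All names and numbers are made up for illustration.

```python
from dataclasses import dataclass
from itertools import combinations
from typing import Iterable, List, Optional, Tuple


@dataclass(frozen=True)
class LayerCost:
    """Hypothetical per-encoder cost model (arbitrary units)."""
    energy: float        # attention computed in full
    delay: float
    area: float
    reuse_energy: float  # attention replaced by a small transformation block
    reuse_delay: float
    reuse_area: float
    acc_penalty: float   # assumed accuracy drop (%) if this encoder reuses attention


def totals(layers: List[LayerCost], reuse: Iterable[int]) -> Tuple[float, float, float]:
    """Sum energy, delay and area over all encoders for a given reuse set."""
    reuse = set(reuse)
    energy = delay = area = 0.0
    for i, c in enumerate(layers):
        if i in reuse:
            energy, delay, area = energy + c.reuse_energy, delay + c.reuse_delay, area + c.reuse_area
        else:
            energy, delay, area = energy + c.energy, delay + c.delay, area + c.area
    return energy, delay, area


def edap(layers: List[LayerCost], reuse: Iterable[int]) -> float:
    """Energy-delay-area product."""
    energy, delay, area = totals(layers, reuse)
    return energy * delay * area


def search(layers: List[LayerCost], delay_budget: float) -> Optional[Tuple[Tuple[int, ...], float, float]]:
    """Exhaustive stand-in search: among reuse sets meeting the delay budget, return
    (reuse_set, estimated_accuracy_drop, edap) with the lowest drop, ties broken by EDAP."""
    best = None
    candidates = range(1, len(layers))  # encoder 0 has no earlier attention to reuse
    for k in range(len(layers)):
        for combo in combinations(candidates, k):
            if totals(layers, combo)[1] > delay_budget:
                continue
            key = (sum(layers[i].acc_penalty for i in combo), edap(layers, combo))
            if best is None or key < best[0]:
                best = (key, combo)
    return None if best is None else (best[1], best[0][0], best[0][1])


# Toy usage with made-up numbers: four identical encoders, different reuse penalties.
model = [LayerCost(10.0, 4.0, 2.0, 1.0, 1.0, 0.2, p) for p in (0.0, 0.2, 0.5, 0.1)]
choice = search(model, delay_budget=13.0)
print(choice[0])                    # (3,)
print(edap(model, []) / choice[2])  # EDAP reduction of the chosen reuse set
```

Exhaustive enumeration visits 2^(L-1) reuse sets for L encoders, which stays cheap for a dozen or so layers. The additive accuracy penalty is the weakest assumption here. In practice the accuracy of a reuse configuration would have to be measured rather than summed per layer.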
