ASCENT-ViT: Attention-based Scale-aware Concept Learning Framework for Enhanced Alignment in Vision Transformers

Sanchit Sinha, Guangzhi Xiong, Aidong Zhang

TL;DR

ASCENT-ViT addresses the challenge of interpretably aligning human concepts with Vision Transformer representations by learning scale- and patch-aware features and aligning them with concepts through attention. It integrates three modules, Multi-scale Encoding (MSE), Deformable Multi-Scale Fusion (DMSF), and the Concept-Representation Alignment Module (CRAM), to produce explanations that reflect both spatial and global concepts while improving classification. Across five datasets (CUB, AWA2, KITS, Pascal aPY, and Concept-MNIST), ASCENT-ViT achieves higher task accuracy and more accurate concept localization than model-agnostic baselines. The approach is robust to transformations and adds explainability with only a small parameter overhead, making it practical for real-world deployment.
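
As a rough illustration of the Multi-scale Encoding idea (features extracted at several scales and concatenated into a single vector $\mathbf{c}$), here is a minimal NumPy sketch. It assumes a single (H, W, D) backbone feature map and uses plain average pooling as a stand-in for whatever learned encoders the actual MSE module uses; `avg_pool`, `multi_scale_encode`, and the default scales are hypothetical choices, not the authors' code.

```python
import numpy as np


def avg_pool(feature_map, out_size):
    """Average-pool an (H, W, D) map down to (out_size, out_size, D).

    Assumption for brevity: H and W are divisible by out_size.
    """
    h, w, d = feature_map.shape
    kh, kw = h // out_size, w // out_size
    # Split each spatial axis into (out_size, kernel) blocks and average within blocks.
    blocks = feature_map.reshape(out_size, kh, out_size, kw, d)
    return blocks.mean(axis=(1, 3))


def multi_scale_encode(feature_map, scales=(1, 2, 4)):
    """Hypothetical MSE stand-in: build a pooled pyramid c_1..c_S and concatenate it.

    Returns the flat vector c and the list of pyramid levels (kept for later fusion).
    """
    pyramid = [avg_pool(feature_map, s) for s in scales]
    c = np.concatenate([level.ravel() for level in pyramid])
    return c, pyramid


fm = np.random.default_rng(0).normal(size=(8, 8, 16))  # toy (H, W, D) feature map
c, pyramid = multi_scale_encode(fm)
print(c.shape)  # (336,)
```

Coarse levels summarize global context while fine levels preserve small structures, which is the kind of scale information the paper credits for localizing small concepts.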

Abstract

As Vision Transformers (ViTs) are increasingly adopted in sensitive vision applications, there is a growing demand for improved interpretability. This has led to efforts to forward-align these models with carefully annotated, abstract, human-understandable semantic entities called concepts. Concepts provide global rationales for model predictions and can be quickly understood and intervened on by domain experts. Most current research focuses on designing model-agnostic, plug-and-play concept-based explainability modules that do not incorporate the inner workings of foundation models (e.g., inductive biases and scale invariance) during training. To address this issue for ViTs, we propose ASCENT-ViT, an attention-based concept learning framework that composes scale- and position-aware representations from multiscale feature pyramids and ViT patch representations, respectively. These representations are then aligned with concept annotations through attention matrices that incorporate both spatial and global (semantic) concepts. ASCENT-ViT can be used as a classification head on top of standard ViT backbones, improving predictive performance and yielding accurate and robust concept explanations, as demonstrated on five datasets: three widely used benchmarks (CUB, Pascal aPY, Concept-MNIST) and two real-world datasets (AWA2, KITS).
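
The composition of multiscale features with ViT patch representations is handled by the DMSF module, which the figure captions describe as a deformable attention operation. The sketch below is a simplified single-head, multi-scale deformable attention in NumPy, in the spirit of Deformable DETR; it is not the authors' implementation. The weight names `W_off`, `W_att`, `W_v`, `W_o`, the [0, 1] reference-point convention, pixel-unit offsets, and the helper functions are all assumptions made for illustration.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


def bilinear_sample(value_map, x, y):
    """Bilinearly sample an (H, W, D) map at pixel coords (x, y); zero padding outside."""
    h, w, d = value_map.shape
    x0, y0 = int(np.floor(x)), int(np.floor(y))
    out = np.zeros(d)
    for yy, wy in ((y0, y0 + 1 - y), (y0 + 1, y - y0)):
        for xx, wx in ((x0, x0 + 1 - x), (x0 + 1, x - x0)):
            if 0 <= yy < h and 0 <= xx < w:
                out += wy * wx * value_map[yy, xx]
    return out


def deformable_fusion(z_q, ref_points, levels, params):
    """Single-head multi-scale deformable attention (simplified sketch).

    z_q:        (N, D) ViT patch embeddings, used as queries
    ref_points: (N, 2) reference point (x, y) of each patch, normalised to [0, 1]
    levels:     list of L multi-scale feature maps, each (H_l, W_l, D)
    params:     hypothetical weights W_off (D, L*K*2), W_att (D, L*K),
                W_v (D, D), W_o (D, D), where K = sampling points per level
    Returns (N, D) fused, scale- and patch-aware representations z.
    """
    num_levels = len(levels)
    num_points = params["W_att"].shape[1] // num_levels
    values = [lvl @ params["W_v"] for lvl in levels]  # value projection per level
    z = np.zeros(z_q.shape)
    for n, q in enumerate(z_q):
        # Each query predicts its own sampling offsets (in pixels of each level)...
        offsets = (q @ params["W_off"]).reshape(num_levels, num_points, 2)
        # ...and one attention weight per sampled point, normalised across all levels.
        weights = softmax(q @ params["W_att"]).reshape(num_levels, num_points)
        for l, v in enumerate(values):
            h, w, _ = v.shape
            base_x = ref_points[n, 0] * (w - 1)  # map the normalised reference to this level
            base_y = ref_points[n, 1] * (h - 1)
            for k in range(num_points):
                dx, dy = offsets[l, k]
                z[n] += weights[l, k] * bilinear_sample(v, base_x + dx, base_y + dy)
    return z @ params["W_o"]
```

Each query attends to only a handful of learned sampling locations per scale rather than to every position, which keeps the cost of multi-scale fusion low. With zero offsets and uniform weights, each query reduces to averaging the values at its reference location across levels, which makes the sketch easy to sanity-check.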

Paper Structure (29 sections, 10 equations, 12 figures, 6 tables)

Figures (12)

  • Figure 1: Schematic overview of the proposed ASCENT-ViT module. Given an input image, the Multi-scale Encoding (MSE) module encodes representations at various scales in $\mathbf{c}$. The Deformable Multi-Scale Fusion (DMSF) module composes these with the patch-aware representations $\mathbf{z_q}$ from the ViT. The Concept-Representation Alignment Module (CRAM) aligns the resulting representations $\mathbf{z}$ with concept annotations $C$. $\hat{y}$ is the model's estimate of the true task label $y$.
  • Figure 2: (a) Detailed view of the MSE module, which extracts multi-scale features $\{c_i\}_{i=1}^{S}$ and concatenates them into a vector $\mathbf{c}$. (b) Detailed view of the proposed DMSF module, which uses a deformable attention operation to combine the multi-scale features $\mathbf{c}$ with the ViT patch embeddings $\mathbf{z_q}$ and outputs scale- and patch-aware representations $\mathbf{z}$.
  • Figure 3: Detailed view of the Concept-Representation Alignment Module (CRAM), which aligns the scale- and patch-aware representations ($\mathbf{z}$) with human-annotated concepts ($C$). $P_q$, $P_k$, $P_{v1}$, and $P_{v2}$ are projection matrices, while $\mathbf{A_{Spatial}}$ and $\mathbf{A_{Global}}$ are attention matrices between the concept projections and the patch embedding projections. The final prediction $\hat{y}$ is the average of the outputs of the spatial and global attention operations.
  • Figure 4: Effect of the most relevant DMSF hyperparameters (the number of attention heads and $\psi$) on performance. The dotted red line shows the baseline performance when only CRAM is used.
  • Figure 5: Effect of scale: visualized concept annotations for a correctly classified KITS sample. ASCENT-ViT localizes the small 'cyst' annotation more accurately (Acc = 91.2%) than CRAM alone (Acc = 88.4%), which uses no scale information.
  • ...and 7 more figures
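
Figure 3 describes CRAM only at the level of projections and attention matrices. The following NumPy sketch shows one plausible reading, loosely modeled on cross-attention concept heads: the patch representations $\mathbf{z}$ act as queries, learnable concept embeddings act as keys and values, the spatial branch attends per patch, the global branch attends from a mean-pooled image query, and the prediction averages the two branches. The concept embeddings, the mean pooling, the class projection `W_cls`, and the squared-error `alignment_loss` against annotations are assumptions, not details confirmed by the paper.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def cram_head(z, spatial_concepts, global_concepts, params):
    """Hypothetical CRAM-style head with a spatial and a global attention branch.

    z:                (N, D) scale- and patch-aware patch representations
    spatial_concepts: (M_s, D) learnable embeddings of spatial concepts (assumed)
    global_concepts:  (M_g, D) learnable embeddings of global concepts (assumed)
    params:           P_q, P_k, P_v1, P_v2 of shape (D, d) and W_cls of shape (d, classes)
    Returns (y_hat, A_spatial, A_global).
    """
    d = params["P_q"].shape[1]
    scale = 1.0 / np.sqrt(d)
    q_patch = z @ params["P_q"]                          # (N, d) patch queries
    q_image = z.mean(axis=0) @ params["P_q"]             # (d,) pooled image query
    # Spatial branch: each patch attends over the spatial concepts.
    k_s = spatial_concepts @ params["P_k"]               # (M_s, d)
    A_spatial = softmax(q_patch @ k_s.T * scale)         # (N, M_s), rows sum to 1
    v_s = spatial_concepts @ params["P_v1"]              # (M_s, d)
    out_spatial = (A_spatial @ v_s).mean(axis=0)         # average over patches -> (d,)
    # Global branch: the pooled image query attends over the global concepts.
    k_g = global_concepts @ params["P_k"]                # (M_g, d)
    A_global = softmax(q_image @ k_g.T * scale)          # (M_g,), sums to 1
    out_global = A_global @ (global_concepts @ params["P_v2"])  # (d,)
    # Final prediction: average of the two branch outputs, mapped to class logits.
    y_hat = ((out_spatial + out_global) / 2) @ params["W_cls"]
    return y_hat, A_spatial, A_global


def alignment_loss(A, C):
    """Assumed supervision: mean squared error between attention and concept annotations."""
    return float(np.mean((A - C) ** 2))
```

In this reading, `A_spatial` and `A_global` double as the explanations: supervising them against concept annotations would be what ties the attention maps to human-understandable concepts, while the averaged branch outputs drive the classification.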