CA-Stream: Attention-based pooling for interpretable image recognition

Felipe Torres, Hanwei Zhang, Ronan Sicre, Stéphane Ayache, Yannis Avrithis

TL;DR

This work addresses the interpretability gap in vision models by linking CAM-based saliency with transformer-style attention. It introduces Cross-Attention Stream (CA-Stream), a parallel pooling mechanism that replaces GAP with a learnable, attention-driven aggregation yielding a global representation $\mathbf{q}_{L+1}$. The key contributions are (i) revealing that cross-attention pooling acts as a class-agnostic CAM, (ii) designing and integrating CA-Stream with CNN backbones to enhance post-hoc explanations, and (iii) demonstrating improved interpretability metrics on ImageNet while preserving recognition accuracy. The result is a practical pathway to more transparent vision systems by embedding explanation-compatible pooling directly into inference.

Abstract

Explanations obtained from transformer-based architectures in the form of raw attention can be seen as a class-agnostic saliency map. Additionally, attention-based pooling serves as a form of masking in the feature space. Motivated by this observation, we design an attention-based pooling mechanism intended to replace Global Average Pooling (GAP) at inference. This mechanism, called Cross-Attention Stream (CA-Stream), comprises a stream of cross-attention blocks interacting with features at different network depths. CA-Stream enhances interpretability in models while preserving recognition performance.
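
As a minimal sketch of the idea of replacing GAP with attention-based pooling, the NumPy snippet below pools a feature matrix with a single query vector. The exact form used here (an unscaled dot product followed by a softmax over patches) and the names `attention_pool`, `F`, and `q` are assumptions made for illustration, not the paper's implementation. It also illustrates one relationship that holds under these assumptions: a zero query gives uniform attention, in which case the pooled vector coincides with GAP.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(x - x.max())
    return e / e.sum()

def gap(F):
    # Global average pooling: every one of the p patches gets weight 1/p.
    return F.mean(axis=0)

def attention_pool(F, q):
    # F: (p, d) patch features (keys); q: (d,) query vector.
    # Assumed form: softmax over patches of the dot-product similarity.
    a = softmax(F @ q)        # (p,) attention map, non-negative, sums to 1
    return F.T @ a, a         # (d,) pooled vector and the attention itself

rng = np.random.default_rng(0)
F = rng.standard_normal((49, 8))   # e.g. a 7x7 feature map with 8 channels
q = rng.standard_normal(8)         # stands in for a learned query

pooled, attn = attention_pool(F, q)
saliency = attn.reshape(7, 7)      # class-agnostic map over spatial positions

# With a zero query the attention is uniform and the pooling reduces to GAP.
pooled_uniform, _ = attention_pool(F, np.zeros(8))
```

The attention vector `attn` is what would be visualized as a raw attention map; it depends only on the features and the query, not on any class.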

Paper Structure (32 sections, 11 equations, 4 figures, 6 tables)

Figures (4)

  • Figure 1: Cross-Attention Stream (CA-Stream) applied to ResNet-based architectures. Given a network $f$, we replace global average pooling (GAP) by a learned, attention-based pooling mechanism implemented as a stream in parallel to $f$. The feature tensor $F_\ell \in \mathbb{R}^{p_\ell \times d_\ell}$ (key) obtained by stage Res-$\ell$ of $f$ interacts with a CLS token (query) embedding $\mathbf{q}_\ell \in \mathbb{R}^{d_\ell}$ in block CA-$\ell$, which contains cross attention (\ref{eq:CA}) followed by a linear projection (\ref{eq:qk-layer}) to adapt to the dimension of $F_{\ell+1}$. Here, $p_\ell$ is the number of patches (spatial resolution) and $d_\ell$ the embedding dimension. The query is initialized by a learnable parameter $\mathbf{q}_0 \in \mathbb{R}^{d_0}$, while the output $\mathbf{q}_5$ of the last cross attention block is used as a global image representation and fed into the classifier.
  • Figure 2: Comparison of saliency maps generated by different CAM-based methods, using GAP and our CA-Stream, on ImageNet images. The raw attention is the one used for pooling by CA-Stream.
  • Figure 3: Visualization of eq. (\ref{eq:connection}). On the left, a feature tensor $\mathbf{F} \in \mathbb{R}^{w \times h \times d}$ is multiplied by the vector ${\boldsymbol{\alpha}} \in \mathbb{R}^d$ in the channel dimension, as in $1 \times 1$ convolution, where $w \times h$ is the spatial resolution and $d$ is the number of channels. This is cross attention (CA) \cite{dosovitskiy2020image} between the query ${\boldsymbol{\alpha}}$ and the key $\mathbf{F}$. On the right, a linear combination of feature maps $F^1, \dots, F^d \in \mathbb{R}^{w \times h}$ is taken with weights $\alpha_1, \dots, \alpha_d$. This is a class activation mapping (CAM) \cite{zhou2016learning} with class-agnostic weights. Eq. (\ref{eq:connection}) expresses the fact that these two quantities are the same, provided that ${\boldsymbol{\alpha}} = (\alpha_1, \dots, \alpha_d)$ and $\mathbf{F}$ is reshaped as $F = (\mathbf{f}^1 \dots \mathbf{f}^d) \in \mathbb{R}^{p \times d}$, where $p = wh$ and $\mathbf{f}^k = \operatorname{vec}(F^k) \in \mathbb{R}^{p}$ is the vectorized feature map of channel $k$.
  • Figure 4: Raw attention maps obtained from our CA-Stream on images of the MIT 67 Scenes dataset \cite{quattoni2009recognizing}, on classes that do not exist in ImageNet. The network sees them at inference for the first time.
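
The captions of Figures 1 and 3 can be illustrated with a small NumPy sketch. It is a sketch under stated assumptions, not the authors' code: the stage shapes are toy values, the attention is taken as an unscaled softmax over patches, the projection matrices `Ws` are random stand-ins for learned weights, and the last projection is assumed to keep the dimension. The first part chains five cross-attention blocks, each followed by a linear projection to the next stage's dimension, to produce a final query vector. The second part checks numerically that multiplying a reshaped feature tensor by a channel weight vector equals the weighted sum of its feature maps, as the Figure 3 caption states.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(x - x.max())
    return e / e.sum()

def ca_block(q, F, W):
    # q: (d_l,) query; F: (p_l, d_l) keys; W: (d_next, d_l) projection.
    # Assumed form of cross attention: softmax over patches, then pooling.
    a = softmax(F @ q)        # (p_l,) raw attention over patches
    pooled = F.T @ a          # (d_l,) attention-weighted average of patches
    return W @ pooled, a      # project to the next stage's dimension

rng = np.random.default_rng(0)

# Toy stage shapes (hypothetical): number of patches p_l and dimension d_l.
patches = [64, 64, 16, 16, 4]
dims = [4, 8, 16, 32, 64]
features = [rng.standard_normal((p, d)) for p, d in zip(patches, dims)]

# Projection l maps d_l to d_{l+1}; the last one keeps d_4 (an assumption).
next_dims = dims[1:] + [dims[-1]]
Ws = [rng.standard_normal((dn, d)) for d, dn in zip(dims, next_dims)]

q = rng.standard_normal(dims[0])   # stands in for the learnable q_0
attention_maps = []
for F, W in zip(features, Ws):
    q, a = ca_block(q, F, W)
    attention_maps.append(a)
q_final = q                        # plays the role of q_5 fed to a classifier

# Figure 3 identity: (F alpha) reshaped to w x h equals sum_k alpha_k F^k.
w, h, d = 3, 5, 6
F_tensor = rng.standard_normal((w, h, d))
alpha = rng.standard_normal(d)
F_mat = F_tensor.reshape(w * h, d)            # column k is vec(F^k)
lhs = (F_mat @ alpha).reshape(w, h)           # cross-attention view
rhs = sum(alpha[k] * F_tensor[:, :, k] for k in range(d))  # CAM view
identity_holds = np.allclose(lhs, rhs)
```

In this sketch `attention_maps[-1]` corresponds to the attention of the last block, which is the quantity that would be reshaped to the spatial grid and displayed as a raw attention map.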