Batch-CAM: Introduction to better reasoning in convolutional deep learning models

Giacomo Ignesti, Davide Moroni, Massimo Martinelli

TL;DR

The paper addresses the opacity of deep learning models in computer vision and proposes Batch-CAM, a training paradigm that fuses a batch-wise Grad-CAM explanation with a prototypical reconstruction loss to steer learning toward evidence-relevant features. The method introduces two losses—Prototype Loss and Batch-CAM Prototype Loss—that enforce attention alignment with class prototypes, computed from data rather than annotations. Empirical results on MNIST and Fashion-MNIST across SimpleCNN, ResNet-18, and ConvNeXt-V2-Tiny show consistent gains in classification accuracy and reconstruction quality, while enabling more coherent saliency maps and efficient batch-wise Grad-CAM computation. This approach advances trustworthy AI by integrating explainability into the training objective, with potential for extension to more complex domains and richer prototype representations.
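The batch-wise Grad-CAM computation mentioned above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. It applies the standard Grad-CAM recipe to every sample of a batch in one vectorized step: global-average-pooled gradients become channel weights, the feature maps are summed with those weights, and a ReLU is applied. It assumes the activations of the chosen convolutional layer and the gradients of the target-class score have already been extracted as arrays of shape (B, K, H, W). The function name `batch_grad_cam` and the per-sample min-max normalization are illustrative choices.

```python
import numpy as np

def batch_grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Compute Grad-CAM maps for a whole batch at once (illustrative sketch).

    activations: feature maps A of the chosen conv layer, shape (B, K, H, W).
    gradients:   d(target-class score)/dA for each sample, same shape.
    Returns per-sample maps of shape (B, H, W), min-max normalized to [0, 1].
    """
    if activations.ndim != 4 or activations.shape != gradients.shape:
        raise ValueError("expected two arrays of identical shape (B, K, H, W)")
    # Channel weights alpha_k: global-average-pool the gradients over H and W.
    weights = gradients.mean(axis=(2, 3))                        # (B, K)
    # Weighted sum of the K feature maps for every sample, then ReLU.
    cam = np.einsum("bk,bkhw->bhw", weights, activations)        # (B, H, W)
    cam = np.maximum(cam, 0.0)
    # Per-sample min-max normalization; a constant map becomes all zeros.
    lo = cam.min(axis=(1, 2), keepdims=True)
    hi = cam.max(axis=(1, 2), keepdims=True)
    return (cam - lo) / np.maximum(hi - lo, 1e-12)

# Usage with random stand-ins for captured activations and gradients.
rng = np.random.default_rng(0)
A = rng.random((4, 8, 7, 7))
G = rng.standard_normal((4, 8, 7, 7))
maps = batch_grad_cam(A, G)
print(maps.shape)  # (4, 7, 7)
```

Because the channel weights and the weighted sum are computed with a single `einsum` over the batch axis, no per-image Python loop is needed once the gradients are available; in a framework such as PyTorch the same arithmetic would be applied to tensors captured with hooks.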

Abstract

Understanding the inner workings of deep learning models is crucial for advancing artificial intelligence, particularly in high-stakes fields such as healthcare, where accurate explanations are as vital as precision. This paper introduces Batch-CAM, a novel training paradigm that fuses a batch implementation of the Grad-CAM algorithm with a prototypical reconstruction loss. This combination guides the model to focus on salient image features, thereby enhancing its performance across classification tasks. Our results demonstrate that Batch-CAM simultaneously improves accuracy and image reconstruction quality while reducing training and inference times. By ensuring that models learn from evidence-relevant information, this approach contributes to building more transparent, explainable, and trustworthy AI systems.
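The summary does not spell out the loss formulas, so the following is only one plausible reading, written as a hypothetical NumPy sketch. Class prototypes are taken to be per-class mean images computed from the training data, and the Batch-CAM prototype loss is assumed to be the L1 distance (the norm named in the Figure 5 caption) between each image masked by its upsampled Grad-CAM map and the prototype of its class. The names `class_prototypes`, `upsample_nearest`, and `batch_cam_prototype_loss` are illustrative, and the separate Prototype Loss is not modeled here.

```python
import numpy as np

def class_prototypes(images: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical prototype: the mean image of each class, computed from data."""
    protos = np.zeros((num_classes,) + images.shape[1:], dtype=np.float64)
    for c in range(num_classes):
        members = images[labels == c]
        if len(members):  # classes absent from the batch keep an all-zero prototype
            protos[c] = members.mean(axis=0)
    return protos

def upsample_nearest(maps: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour upsampling of (B, h, w) maps to (B, size, size)."""
    _, h, w = maps.shape
    if size % h or size % w:
        raise ValueError("size must be an integer multiple of the map size")
    return maps.repeat(size // h, axis=1).repeat(size // w, axis=2)

def batch_cam_prototype_loss(images: np.ndarray, cams: np.ndarray,
                             labels: np.ndarray, protos: np.ndarray) -> float:
    """Assumed form: mean L1 distance between CAM-masked images and class prototypes."""
    masked = images * upsample_nearest(cams, images.shape[-1])
    return float(np.abs(masked - protos[labels]).mean())

# Usage with random grayscale 28x28 images, 7x7 CAMs, and three classes.
rng = np.random.default_rng(0)
x = rng.random((6, 28, 28))
y = np.array([0, 1, 2, 0, 1, 2])
P = class_prototypes(x, y, num_classes=3)
cams = rng.random((6, 7, 7))
loss = batch_cam_prototype_loss(x, cams, y, P)
print(P.shape)  # (3, 28, 28)
```

In an actual training loop, a term like this would be added to the cross-entropy loss with some weighting factor and computed in an autodiff framework so that gradients can flow back through the CAM; NumPy is used here only to make the arithmetic explicit.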

Paper Structure

This paper contains 13 sections, 2 equations, 7 figures, 3 tables, 2 algorithms.

Figures (7)

  • Figure 1: Proposed model training protocol with both the prototype loss and the Batch-CAM prototype loss
  • Figure 2: Average MNIST reconstruction prototypes per class for 23x23 images on architectures such as ResNet and ConvNeXt-V2
  • Figure 3: Average MNIST reconstruction prototypes per class for 23x23 images on the CNN architecture with the classical cross-entropy loss
  • Figure 4: Average MNIST reconstruction prototypes per class for 23x23 images on the CNN architecture with the Batch-CAM Prototype Loss
  • Figure 5: Average reconstruction prototypes per class for 112x112 MNIST images on ConvNeXt-V2 with the Batch-CAM Prototype Loss (L1)
  • ...and 2 more figures