Captum: A unified and generic model interpretability library for PyTorch

Narine Kokhlikyan, Vivek Miglani, Miguel Martin, Edward Wang, Bilal Alsallakh, Jonathan Reynolds, Alexander Melnikov, Natalia Kliushkina, Carlos Araya, Siqi Yan, Orion Reblitz-Richardson

TL;DR

Captum addresses the need for a unified, PyTorch-friendly toolset for model interpretability by providing gradient- and perturbation-based attribution algorithms across modalities. It introduces a scalable architecture with memory-efficient computation, a unified API, and evaluation metrics (infidelity and maximum sensitivity), plus the Captum Insights visualization tool. The library supports multi-modal inputs and can scale across GPUs, making it suitable for research and production. The work emphasizes extensibility and easy adoption across domains beyond computer vision.

Abstract

In this paper we introduce a novel, unified, open-source model interpretability library for PyTorch [12]. The library contains generic implementations of a number of gradient- and perturbation-based attribution algorithms, also known as feature, neuron and layer importance algorithms, as well as a set of evaluation metrics for these algorithms. It can be used for both classification and non-classification models, including graph-structured models built on Neural Networks (NN). We give a high-level overview of the supported attribution algorithms and show how to perform memory-efficient and scalable computations. The three main characteristics of the library are multimodality, extensibility and ease of use. Multimodality means support for different input modalities such as image, text, audio or video. Extensibility allows new algorithms and features to be added. The library is also designed to be easy to understand and use. We also introduce an interactive visualization tool called Captum Insights, which is built on top of the Captum library and allows sample-based model debugging and visualization using feature importance metrics.

Paper Structure

This paper contains 9 sections, 2 equations, 5 figures, 1 table.

Figures (5)

  • Figure 1: An overview of all three types of attribution variants with example code snippets. The first variant, depicted on the far-left side of the diagram, represents primary attribution. The algorithms in this group attribute the outputs of the model to its inputs. The middle variant attributes internal neurons to the inputs of the model, and the right-most variant attributes the outputs of the model to all neurons in a hidden layer.
  • Figure 2: An overview of all the attribution algorithms in Captum. The algorithms grouped on the left side of the diagram are the primary and neuron attribution algorithms. The ones on the right side of the diagram are layer attribution variants. The color coding is orange for gradient-based algorithms, green for perturbation-based algorithms, and blue for algorithms that are neither perturbation- nor gradient-based.
  • Figure 3: Visualizing salient tokens computed by integrated gradients that contribute to the predicted class, using a binary classification model trained on the IMDB dataset. Green tokens pull towards the predicted class and red tokens pull away from it. The intensity of the color signifies the magnitude of the signal.
  • Figure 4: Visualizing normalized attribution scores and weights for all ten neurons in the last linear layer of a simple four-layer MLP model trained on the Boston house prices dataset.
  • Figure 5: Captum Insights interactive visualization tool after applying integrated gradients to the Visual Question Answering multi-modal model. The tool also visualizes aggregated attribution magnitudes of each modality.
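
Figures 3 and 5 both rely on integrated gradients, one of the primary attribution algorithms. The idea can be sketched without any dependencies. The block below is an illustration under stated assumptions and is not Captum's implementation: the toy model, its weights, the all-zero baseline and the midpoint-rule approximation of the path integral are all made up for this example. It attributes the model output to each input feature and checks that the attributions sum to the difference between the output at the input and at the baseline.

```python
import math

# Hypothetical toy model: sigmoid of a weighted sum. Weights and bias are made up.
WEIGHTS = [1.5, -2.0, 0.5]
BIAS = 0.1


def forward(x):
    """Model output for a single input vector x."""
    z = sum(w * xi for w, xi in zip(WEIGHTS, x)) + BIAS
    return 1.0 / (1.0 + math.exp(-z))


def gradient(x):
    """Analytic gradient of forward(x) with respect to each input feature."""
    y = forward(x)
    return [y * (1.0 - y) * w for w in WEIGHTS]


def integrated_gradients(x, baseline, n_steps=50):
    """Approximate integrated gradients along the straight path baseline -> x.

    The path integral is approximated with a midpoint Riemann sum
    (an assumption of this sketch; other quadrature rules work too).
    """
    grad_sums = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(gradient(point)):
            grad_sums[i] += g
    # Scale the averaged gradient by the input-baseline difference.
    return [(xi - b) * s / n_steps for xi, b, s in zip(x, baseline, grad_sums)]


x = [1.0, 0.5, 2.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline)

# Approximation error: attributions should sum to forward(x) - forward(baseline).
delta = sum(attributions) - (forward(x) - forward(baseline))
print("attributions:", attributions)
print("convergence delta:", delta)
```

A positive attribution means the feature pushes the output up relative to the baseline and a negative one means it pulls the output down, which matches the green and red coloring described for Figure 3. The delta printed at the end is a quick sanity check: if it is not close to zero, more integration steps are needed.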