Gradient-free Post-hoc Explainability Using Distillation Aided Learnable Approach

Debarpan Bhattacharya, Amir H. Poorjam, Deepak Mittal, Sriram Ganapathy

TL;DR

This paper proposes a framework, named distillation aided explainability (DAX), that aims to generate saliency-based explanations in a model-agnostic, gradient-free manner, and extensively evaluates DAX across different modalities (image and audio) in a classification setting, using a diverse set of evaluation measures.

Abstract

The recent advancements in artificial intelligence (AI), with the release of several large models that offer only query access, make a strong case for explaining deep models in a post-hoc, gradient-free manner. In this paper, we propose a framework, named distillation aided explainability (DAX), that aims to generate saliency-based explanations in a model-agnostic, gradient-free setting. DAX poses the problem of explanation in a learnable setting with two components: a mask generation network and a student (distillation) network. The mask generation network learns to generate a multiplicative mask that identifies the salient regions of the input, while the student network aims to approximate the local behavior of the black-box model. We propose a joint optimization of the two networks using locally perturbed input samples, with targets obtained through input-output access to the black-box model. We extensively evaluate DAX across different modalities (image and audio) in a classification setting, using a diverse set of evaluations (intersection over union with the ground truth, deletion-based measures, and subjective human evaluation) and benchmark it against $9$ different methods. In these evaluations, DAX significantly outperforms the existing approaches on all modalities and evaluation metrics.
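The joint optimization described above can be illustrated with a minimal pure-Python sketch. It rests on several simplifying assumptions that are not taken from the paper: the mask is a directly learnable vector of logits (passed through a sigmoid to stay in (0, 1)) rather than the output of a mask generation network; the student is a single logistic unit, i.e. a mild non-linearity; perturbations randomly keep or drop each input feature; the distillation loss is mean squared error; and sparsity comes from an L1 penalty on the mask weighted by a hypothetical `lam`. The names `dax_style_explain` and `toy_black_box` are illustrative only. The black box is only queried, so gradients flow through the mask and the student, never through the black box.

```python
import math
import random


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def toy_black_box(z):
    """Hypothetical query-only classifier that uses only features 0 and 2."""
    return sigmoid(4.0 * z[0] + 4.0 * z[2] - 4.0)


def dax_style_explain(f_bb, x, n_perturb=64, steps=2000, lr=0.5, lam=0.02, seed=0):
    rng = random.Random(seed)
    d = len(x)
    # Assumption: local perturbations randomly keep (1) or drop (0) each feature.
    keeps = [[float(rng.random() < 0.5) for _ in range(d)] for _ in range(n_perturb)]
    x_q = [[xi * ki for xi, ki in zip(x, k)] for k in keeps]
    # Targets come only from input-output queries to the black box.
    y_q = [f_bb(xq) for xq in x_q]

    # Simplification (assumption): mask logits are learned directly,
    # not produced by a mask generation network.
    u = [0.0] * d  # mask logits
    v = [0.0] * d  # student weights (single logistic unit, an assumption)
    b = 0.0        # student bias
    for _ in range(steps):
        m = [sigmoid(ui) for ui in u]
        grad_u = [0.0] * d
        grad_v = [0.0] * d
        grad_b = 0.0
        for xq, y in zip(x_q, y_q):
            z = [mi * xi for mi, xi in zip(m, xq)]  # mask-multiplied input
            s = sigmoid(sum(vi * zi for vi, zi in zip(v, z)) + b)  # student output
            g = (s - y) * s * (1.0 - s)  # d(0.5 * (s - y)^2) / d(pre-activation)
            for j in range(d):
                grad_v[j] += g * z[j]
                grad_u[j] += g * v[j] * xq[j] * m[j] * (1.0 - m[j])
            grad_b += g
        # Joint gradient step on mask and student; the L1 term lam * sum(m)
        # contributes lam * m * (1 - m) to each mask-logit gradient.
        for j in range(d):
            u[j] -= lr * (grad_u[j] / n_perturb + lam * m[j] * (1.0 - m[j]))
            v[j] -= lr * grad_v[j] / n_perturb
        b -= lr * grad_b / n_perturb
    return [sigmoid(ui) for ui in u]


mask = dax_style_explain(toy_black_box, [1.0, 1.0, 1.0, 1.0])
print([round(mi, 3) for mi in mask])
```

Because the toy black box depends only on features 0 and 2, the learned mask should end up larger on those two features than on features 1 and 3, whose mask entries are driven down by the L1 penalty.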

Paper Structure (31 sections, 13 equations, 10 figures, 9 tables, 1 algorithm)

Figures (10)

  • Figure 1: Comparison of various XAI methods and DAX on an input strawberry image, showing the explanation heat map and the mask-multiplied input.
  • Figure 2: Contrasting the local linear approximation approach LIME [ribeiro2016should] for explainability with DAX (this work). (a) LIME segments the input image $x$; let the number of segments be $S$. (b) It then masks off segments at random and records the corresponding black-box responses $y^1, ..., y^Q$, where $Q$ is the number of perturbation samples. (c) Each perturbed image is represented by a binary vector of length $S$, where a $1$ indicates that the corresponding segment is masked off. (d) Using the resulting $Q \times S$ binary matrix as input and $y^1, ..., y^Q$ as targets, LIME fits a linear model. After training, the linear weights serve as the explanation of the input, with one weight per segment location. (e) DAX, in contrast, operates directly in the image space and uses a non-linear local approximation to generate the explanation.
  • Figure 3: Let $f_{BB}(\cdot; \theta_{BB})$ be the trained black-box neural network whose decision boundary (separating two classes) is shown by the solid black lines. The decision boundary has different local curvatures near the inputs $x_k$ (panel (a)) and $x_l$ (panel (b)). (a) A locally linear assumption (red dashed line) is a good approximation of the black box at $x_k$. (b) Because the decision boundary is less smooth near $x_l$, a linear assumption (red dashed line) is a poor local approximation of the black box, whereas a mild non-linearity (blue dashed line) can significantly reduce the approximation error.
  • Figure 4: The DAX framework. Part (a) shows the two components of the model, highlighted in the bottom row: (i) the mask generation network and (ii) the student network. Part (b) illustrates the inference step.
  • Figure 5: The learnable explanation converges in fewer than $10$ epochs. The left panel shows the masks generated at different epochs, and the right panel shows the training and validation loss curves.
  • ...and 5 more figures
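Steps (b) to (d) of the LIME baseline contrasted in Figure 2 can be sketched in pure Python. This is a simplified illustration, not the LIME library: it assumes the segmentation of step (a) is already done, models the black box as a hypothetical function of the binary masked-off vector, and fits ordinary least squares through the normal equations, omitting the proximity weighting that LIME applies to perturbed samples. `lime_style_weights` and `toy_segment_black_box` are hypothetical names.

```python
import random


def solve_linear_system(a, b):
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(a)
    m = [row[:] + [bi] for row, bi in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def lime_style_weights(f_bb, n_segments, n_perturb=50, seed=0):
    rng = random.Random(seed)
    # (b)-(c): one binary vector of length S per sample; 1 = segment masked off.
    z = [[int(rng.random() < 0.5) for _ in range(n_segments)] for _ in range(n_perturb)]
    y = [f_bb(row) for row in z]  # black-box responses y^1..y^Q
    # (d): the Q x S binary matrix plus an intercept column is the design matrix.
    # Simplification: plain least squares, no LIME proximity kernel.
    design = [[1.0] + [float(v) for v in row] for row in z]
    p = n_segments + 1
    xtx = [[sum(r[i] * r[j] for r in design) for j in range(p)] for i in range(p)]
    xty = [sum(r[i] * yi for r, yi in zip(design, y)) for i in range(p)]
    coef = solve_linear_system(xtx, xty)
    return coef[0], coef[1:]  # intercept, one weight per segment


def toy_segment_black_box(masked_off):
    """Hypothetical black box whose score depends on keeping segments 0 and 2."""
    keep = [1 - v for v in masked_off]
    return 0.2 + 0.5 * keep[0] + 0.3 * keep[2]


intercept, weights = lime_style_weights(toy_segment_black_box, 4)
print(round(intercept, 3), [round(w, 3) for w in weights])
```

Because a $1$ means the segment is masked off, a strongly negative weight marks a segment whose removal lowers the score, i.e. a salient segment. The surrogate is purely linear, which is the assumption Figure 3 argues can break down where the decision boundary is curved.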