
Progressive Inference: Explaining Decoder-Only Sequence Classification Models Using Intermediate Predictions

Sanjay Kariyappa, Freddy Lécué, Saumitra Mishra, Christopher Pond, Daniele Magazzeni, Manuela Veloso

TL;DR

The paper addresses the problem of explaining decoder-only sequence classification models by deriving input attributions from intermediate predictions made along the input sequence. It introduces Progressive Inference (PI), which exploits the causal attention of decoder-only models to interpret each intermediate prediction $\vec{p}_i$ as the model's prediction on a masked input, so explanations come at little extra cost. SP-PI computes $\phi_i = p_i^c - p_{i-1}^c$ and requires only a single forward pass; MP-PI uses multiple masked passes and Kernel SHAP with an optimized mask distribution to produce SHAP-like attributions. Empirical results on GPT-2 and Llama-2 across several text-classification tasks show that SP-PI and MP-PI outperform prior XAI methods in attribution quality, with MP-PI giving the strongest gains and near-SHAP performance at higher sample efficiency.

Abstract

This paper proposes Progressive Inference, a framework to compute input attributions to explain the predictions of decoder-only sequence classification models. Our work is based on the insight that the classification head of a decoder-only Transformer model can be used to make intermediate predictions by evaluating it at different points in the input sequence. Due to the causal attention mechanism, these intermediate predictions only depend on the tokens seen before the inference point, allowing us to obtain the model's prediction on a masked input sub-sequence with negligible computational overhead. We develop two methods to provide sub-sequence level attributions using this insight. First, we propose Single Pass-Progressive Inference (SP-PI), which computes attributions by taking the difference between consecutive intermediate predictions. Second, we exploit a connection with Kernel SHAP to develop Multi Pass-Progressive Inference (MP-PI). MP-PI uses intermediate predictions from multiple masked versions of the input to compute higher quality attributions. Our studies on a diverse set of models trained on text classification tasks show that SP-PI and MP-PI provide significantly better attributions than prior work.
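
As a rough illustration of the SP-PI idea described above, the sketch below takes a list of intermediate class-probability vectors and returns the differences between consecutive predictions for a target class. The function name `sp_pi_attributions`, the toy probabilities, and the use of a uniform vector as the prediction before any token is seen are assumptions made for this example, not details taken from the paper.

```python
from typing import List, Sequence


def sp_pi_attributions(
    intermediate_probs: Sequence[Sequence[float]],
    baseline_probs: Sequence[float],
    target_class: int,
) -> List[float]:
    """Return phi_i = p_i^c - p_{i-1}^c for each position i.

    baseline_probs plays the role of p_0 (the prediction before any
    feature is seen); how p_0 is chosen is an assumption of this sketch.
    """
    previous = baseline_probs[target_class]
    attributions = []
    for probs in intermediate_probs:
        current = probs[target_class]
        attributions.append(current - previous)
        previous = current
    return attributions


# Toy intermediate predictions for a 4-feature input and 2 classes.
# In practice these would come from applying the classification head
# at each position during one forward pass of a decoder-only model.
probs = [[0.5, 0.5], [0.3, 0.7], [0.6, 0.4], [0.1, 0.9]]
baseline = [0.5, 0.5]  # assumed uniform p_0

phi = sp_pi_attributions(probs, baseline, target_class=1)
print(phi)
```

Because the differences telescope, the attributions sum to the final prediction minus the baseline, $\sum_i \phi_i = p_n^c - p_0^c$, which gives a quick sanity check on any implementation.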

Paper Structure

This paper contains 28 sections, 16 equations, 9 figures, 6 tables, and 1 algorithm.

Figures (9)

  • Figure 1: Input tokens are fed to a decoder-only model to produce intermediate predictions. Progressive Inference (PI) uses these predictions to produce attributions over input tokens, words, or sentences. Single-Pass PI uses the intermediate predictions produced by the original input tokens, whereas Multi-Pass PI collects multiple sets of intermediate predictions from different masked versions of the input to compute the attributions.
  • Figure 2: Comparing the attributions produced by MP-PI with prior works on a misclassified movie review from the IMDB dataset. Only MP-PI manages to correctly identify negative sentences.
  • Figure 3: SP-PI uses the original input $\vec{x}$ to produce intermediate predictions $\{\vec{p}_i\}$. The PI framework treats these intermediate predictions as approximations of the model's prediction on the corresponding masked versions of the input: $\vec{p}_i \approx f(\vec{x}'_i)$. SP-PI takes the difference between consecutive intermediate predictions to compute feature-level attributions $\{\phi_i\}$.
  • Figure 4: MP-PI runs progressive inference multiple times with different masked versions of the input. It starts by sampling a binary mask $\vec{z}'$ to create a masked input $\vec{x}'$. PI interprets the intermediate predictions $\{\vec{p}_i\}$ generated from $\vec{x}'$ as predictions of the model on different perturbed versions of the input $\{\vec{x}'_i=h_{\vec{x}}(\vec{z}'_i)\}$. The set of (coalition, prediction) pairs $(S_i, \vec{p}_i)$ is filtered to remove repeated coalitions and added to the dataset $D$. Finally, Kernel SHAP is applied to $D$ to produce the input attributions $\{\phi_i\}$.
  • Figure 5: Distribution of cosine similarities between Kernel SHAP and MP-PI attributions. For most datasets, the cosine similarity is high, indicating that the attributions produced by MP-PI do approximate SHAP values.
  • ...and 4 more figures
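
The dataset-construction step described in the Figure 4 caption can be sketched roughly as follows. The sketch assumes that features are indexed by position, that the coalition $S_i$ seen at position $i$ is the set of unmasked features among the first $i$ positions, and that each prediction is a single scalar for the target class. All names (`prefix_coalitions`, `add_pass_to_dataset`, `shapley_kernel_weight`) and the toy numbers are hypothetical. The weight function is the standard Kernel SHAP kernel; the paper's optimized mask distribution and its final regression step are not reproduced here.

```python
import math
from typing import Dict, FrozenSet, List, Sequence


def prefix_coalitions(mask: Sequence[int]) -> List[FrozenSet[int]]:
    """For each position i, return the unmasked features seen up to i."""
    active = set()
    coalitions = []
    for i, bit in enumerate(mask):
        if bit:
            active.add(i)
        coalitions.append(frozenset(active))
    return coalitions


def add_pass_to_dataset(
    dataset: Dict[FrozenSet[int], float],
    mask: Sequence[int],
    intermediate_preds: Sequence[float],
) -> Dict[FrozenSet[int], float]:
    """Add (coalition, prediction) pairs from one masked pass.

    Repeated coalitions are skipped, keeping the first prediction seen.
    """
    for coalition, pred in zip(prefix_coalitions(mask), intermediate_preds):
        dataset.setdefault(coalition, pred)
    return dataset


def shapley_kernel_weight(num_features: int, coalition_size: int) -> float:
    """Standard Kernel SHAP weight for a coalition of the given size."""
    if coalition_size in (0, num_features):
        return float("inf")  # empty and full coalitions act as constraints
    return (num_features - 1) / (
        math.comb(num_features, coalition_size)
        * coalition_size
        * (num_features - coalition_size)
    )


# Two toy masked passes over a 4-feature input. The predictions stand in
# for intermediate outputs of a model and are made up for illustration.
dataset: Dict[FrozenSet[int], float] = {}
add_pass_to_dataset(dataset, [1, 0, 1, 1], [0.6, 0.6, 0.2, 0.8])
add_pass_to_dataset(dataset, [0, 1, 1, 0], [0.5, 0.4, 0.3, 0.3])

for coalition, pred in dataset.items():
    weight = shapley_kernel_weight(4, len(coalition))
    print(sorted(coalition), pred, weight)
```

Each masked pass here yields up to one new coalition per position, which is the source of the sample efficiency the summary attributes to MP-PI: a conventional Kernel SHAP evaluation would need a separate forward pass for every coalition.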
