GAProtoNet: A Multi-head Graph Attention-based Prototypical Network for Interpretable Text Classification

Ximing Wen, Wenjuan Tan, Rosina O. Weber

TL;DR

GAProtoNet tackles interpretability in text classification by combining a language-model encoder with a white-box prototypical network that uses multi-head graph attention to connect input embeddings to $M$ prototypes. Predictions arise from a linear combination of prototypes weighted by attention scores, so each decision can be explained through the attention weights and the training texts that the prototypes correspond to. Empirical results on diverse datasets show GAProtoNet achieves competitive or superior accuracy and F1 compared to strong black-box LMs and outperforms existing prototype-based methods, with multi-head variants offering clear gains. The work demonstrates that graph attention can efficiently learn relational structure between inputs and prototypes, producing transparent decisions supported by prototype visualizations and case studies, though the authors acknowledge the need for user studies to further validate explainability in practice.

Abstract

Pretrained transformer-based Language Models (LMs) are well-known for their ability to achieve significant improvement on text classification tasks with their powerful word embeddings, but their black-box nature, which leads to a lack of interpretability, has been a major concern. In this work, we introduce GAProtoNet, a novel white-box Multi-head Graph Attention-based Prototypical Network designed to explain the decisions of text classification models built with LM encoders. In our approach, the input vector and prototypes are regarded as nodes within a graph, and we utilize multi-head graph attention to selectively construct edges between the input node and prototype nodes to learn an interpretable prototypical representation. During inference, the model makes decisions based on a linear combination of activated prototypes weighted by the attention score assigned for each prototype, allowing its choices to be transparently explained by the attention weights and the prototypes projected into the closest matching training examples. Experiments on multiple public datasets show our approach achieves superior results without sacrificing the accuracy of the original black-box LMs. We also compare with four alternative prototypical network variations and our approach achieves the best accuracy and F1 among all. Our case study and visualization of prototype clusters also demonstrate the efficiency in explaining the decisions of black-box models built with LMs.

Paper Structure (35 sections, 13 equations, 8 figures, 1 table)

Figures (8)

  • Figure 1: Overview of the GAProtoNet architecture. GAProtoNet consists of three primary components: a text embedding layer using LMs, a prototype layer, and graph attention. The text embeddings are linearly transformed to produce query vectors. Each attention head constructs a graph based on the attention scores between the query vectors and predefined prototypes. As shown in the figure, different heads activate prototypes in different semantic aspects, assigning different attention scores to negative and positive prototypes. An interpretable prototypical representation is formed as a linear combination of all prototypes weighted by the attention scores and sent to the output layer for classification.
  • Figure 2: Prototype activation under Attention Head 3.
  • Figure 3: Prototypes activated by the graph attention heads for a Yelp review whose label and prediction are both positive. For each prototype, each attention head is checked for whether it constructs an edge between the prototype and the input. Different aspects are highlighted in distinct colors: blue for price, yellow for service, red for wait time, and green for food quality.
  • Figure 4: Hotel
  • Figure 5: IMDb
  • ...and 3 more figures
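
The pipeline described in the Figure 1 caption (query projection, per-head attention over prototypes, attention-weighted combination, output layer) can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: the scaled dot-product scoring, the dense softmax over all prototypes, the averaging across heads, and every function and parameter name below are assumptions made for the example. The paper's own scoring function, edge selection, and head aggregation may differ.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def head_attention(query, prototypes):
    """Attention weights of one head over all prototypes.

    Assumption: scaled dot-product scores followed by softmax.
    The query must have the same dimension as each prototype.
    """
    scale = math.sqrt(len(query))
    scores = [sum(q * p for q, p in zip(query, proto)) / scale
              for proto in prototypes]
    return softmax(scores)


def prototypical_representation(embedding, head_projections, prototypes):
    """Return (representation, per-head attention weights).

    Each head linearly projects the text embedding into a query,
    scores it against every prototype, and forms an attention-weighted
    sum of the prototypes. Heads are averaged (an assumption).
    """
    dim = len(prototypes[0])
    per_head_weights = []
    per_head_outputs = []
    for projection in head_projections:
        query = matvec(projection, embedding)
        weights = head_attention(query, prototypes)
        combined = [sum(w * proto[j] for w, proto in zip(weights, prototypes))
                    for j in range(dim)]
        per_head_weights.append(weights)
        per_head_outputs.append(combined)
    n_heads = len(per_head_outputs)
    representation = [sum(out[j] for out in per_head_outputs) / n_heads
                      for j in range(dim)]
    return representation, per_head_weights


def classify(representation, output_weights):
    """Linear output layer: one logit per class, then argmax."""
    logits = matvec(output_weights, representation)
    return logits.index(max(logits)), logits


# Toy inputs (hypothetical values, chosen only for illustration).
embedding = [1.0, 0.0, -1.0]                          # LM text embedding
prototypes = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # M = 3 prototypes
head_projections = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],               # head 1
    [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0]],               # head 2
]
output_weights = [[1.0, 0.0], [-1.0, 0.0]]            # 2 classes

rep, attn = prototypical_representation(embedding, head_projections, prototypes)
label, logits = classify(rep, output_weights)
for h, weights in enumerate(attn, start=1):
    print(f"head {h} attention over prototypes:", [round(w, 3) for w in weights])
print("predicted class index:", label)
```

Because each head's weights sum to 1 over the prototypes in this sketch, they can be read as that head's share of attention per prototype, which is the kind of quantity the attention-based explanations rely on.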