AtMan: Understanding Transformer Predictions Through Memory Efficient Attention Manipulation

Björn Deiseroth, Mayukh Deb, Samuel Weinbach, Manuel Brack, Patrick Schramowski, Kristian Kersting

TL;DR

This paper tackles the challenge of explaining the predictions of large autoregressive transformer models without incurring prohibitive memory costs. It introduces AtMan, a memory-efficient, modality-agnostic perturbation method that manipulates attention scores inside the model to produce input relevance maps, replacing gradient-based backpropagation. AtMan perturbs each input token by scaling down the attention scores assigned to it, so the perturbation happens in the embedding space rather than in the raw input. For correlated inputs such as image patches, it also suppresses the token's cosine-similarity neighbors so that redundant information cannot compensate for the suppressed token. AtMan achieves competitive or superior explanations on text and image-text benchmarks while using far less memory than gradient-based methods. The approach therefore scales to very large multimodal models and can produce explanations at inference time in production deployments. In the reported experiments, AtMan outperforms existing gradient-based XAI methods on SQuAD and OpenImages and remains memory- and compute-efficient on 6B, 13B, and 30B parameter models.
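The attention manipulation can be made concrete with a minimal single-head sketch in plain Python. It assumes that suppressing token i means multiplying the pre-softmax score of every query with respect to key i by (1 - f), before the causal mask is applied, as Figure 1b describes. The function name `manipulated_attention` and the value f = 0.9 are illustrative choices, not taken from the authors' code. Scaling moves a score toward zero, so it reduces attention only when the score is positive. The toy scores below are chosen to be positive for that reason.

```python
import math

def softmax(xs):
    # Numerically stable softmax; math.exp(-inf) is 0.0, so masked entries get weight 0.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def manipulated_attention(scores, suppress=None, factor=0.9):
    """Causal attention probabilities with optional AtMan-style suppression (sketch).

    scores:   n x n pre-softmax scores (rows = queries, columns = keys).
    suppress: index of the key token to suppress, or None for the plain pass.
    factor:   hypothetical suppression strength f; column `suppress` is scaled by 1 - f.
    """
    n = len(scores)
    probs = []
    for q in range(n):
        row = []
        for k in range(n):
            # Modifier: 1 for every key except the suppressed one, which gets 1 - f.
            modifier = 1.0 - factor if k == suppress else 1.0
            s = scores[q][k] * modifier
            # Causal mask applied after the modifier: a query cannot see later keys.
            if k > q:
                s = float("-inf")
            row.append(s)
        probs.append(softmax(row))
    return probs

scores = [[1.0, 0.0, 0.0],
          [2.0, 1.0, 0.0],
          [3.0, 2.0, 1.0]]
plain = manipulated_attention(scores)
suppressed = manipulated_attention(scores, suppress=0)
print(round(plain[2][0], 3), round(suppressed[2][0], 3))  # prints 0.665 0.118
```

The last query's attention to token 0 drops from about 0.665 to 0.118. The token stays in the context and is only de-emphasized. Figure 2b contrasts this kind of steering with hard token masking.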

Abstract

Generative transformer models have become increasingly complex, with large numbers of parameters and the ability to process multiple input modalities. Current methods for explaining their predictions are resource-intensive. Most crucially, they require prohibitively large amounts of extra memory, since they rely on backpropagation, which allocates almost twice as much GPU memory as the forward pass. This makes it difficult, if not impossible, to use them in production. We present AtMan, which provides explanations of generative transformer models at almost no extra cost. Specifically, AtMan is a modality-agnostic perturbation method that manipulates the attention mechanisms of transformers to produce relevance maps for the input with respect to the output prediction. Instead of using backpropagation, AtMan applies a parallelizable, token-based search that relies on cosine-similarity neighborhoods in the embedding space. Our extensive experiments on text and image-text benchmarks demonstrate that AtMan outperforms current state-of-the-art gradient-based methods on several metrics while being computationally efficient. As such, AtMan is suitable for use in large-model inference deployments.
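The cosine-similarity neighborhood can be sketched as a function that takes the embeddings of all input tokens and returns the per-token attention modifiers used when token i is suppressed. The exact weighting is an assumption. Here, every token whose cosine similarity s to token i reaches a threshold kappa is scaled by 1 - f*s, and all other tokens are left untouched. Because s = 1 for token i itself, this reduces to the single-token modifier 1 - f. The values f = 0.9 and kappa = 0.7, and the name `neighborhood_modifiers`, are hypothetical.

```python
import math

def cosine(u, v):
    # Cosine similarity between two embedding vectors.
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def neighborhood_modifiers(embeddings, i, factor=0.9, kappa=0.7):
    """Attention modifiers for suppressing token i together with its neighborhood.

    Assumed weighting (hypothetical): token k is scaled by 1 - factor * sim(i, k)
    if sim(i, k) >= kappa, and left untouched (modifier 1.0) otherwise.
    """
    modifiers = []
    for emb in embeddings:
        s = cosine(embeddings[i], emb)
        modifiers.append(1.0 - factor * s if s >= kappa else 1.0)
    return modifiers

# Two near-duplicate "patches" and one unrelated patch.
patches = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
print([round(m, 3) for m in neighborhood_modifiers(patches, 0)])  # prints [0.1, 0.106, 1.0]
```

Weighting by similarity suppresses a near-duplicate patch almost as strongly as token i itself. A redundant copy therefore cannot take over the suppressed token's role. This corresponds to case B in Figure 3, where the unrelated patch keeps its full attention.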

Paper Structure

This paper contains 43 sections, 8 equations, 19 figures, and 4 tables.

Figures (19)

  • Figure 1: (a) The proposed explainability method AtMan visualizes the most important aspects of the given image while the model completes the sequence. The completion is displayed above the relevance maps. The generative multimodal model MAGMA is prompted to describe the shown image with "<Image> This is a painting of ". (b) Integration of AtMan into the transformer architecture. The modifier factors are multiplied with the attention scores before the diagonal causal attention mask is applied, as depicted on the right-hand side. Red hollow boxes indicate values of one, and green ones indicate -infinity. (Best viewed in color.)
  • Figure 2: (a) Illustration of the proposed explainability method. First, we collect the original cross-entropy score of the target tokens (1). Then we iterate over the input, suppressing one token at a time (red box), and track the changes in the cross-entropy score of the target token (2). (b) Manipulating the attention scores, highlighted in blue, steers the model's prediction in a different contextual direction. We found that measuring such generative directions works better than radically masking tokens, as shown later (cf. Fig. [fig:sweep]). (Best viewed in color.)
  • Figure 3: (a) Correlated token suppression in AtMan enhances explainability in the image domain. (i) An input image with three perturbation examples ($A$, $B$, $C$). In $A$, only a single image token (blue) is suppressed. In $B$, the same token is suppressed together with its cosine-similarity neighborhood (yellow). In $C$, an unrelated token is suppressed together with its neighborhood. The corresponding changes in cross-entropy loss are shown below. The original score for the target token "panda" is denoted by $c_0$ and the loss change by $\Delta$. (ii) The resulting explanations without and with cosine similarity (CS). The influence of CS is evaluated quantitatively in Fig. [fig:evalcst]. (b) An example from the SQuAD dataset with AtMan explanations. The instance contains three questions about a given context, each with a labeled answer pointing to a fragment of the context. AtMan highlights the fragments of the text responsible for each answer. According to the given labels, the green answer is fully recovered, the blue one is partially recovered, and the yellow one is not recovered at all. However, the yellow highlight appears at least related to the label. (Best viewed in color.)
  • Figure 4: For multi-class weak segmentation, AtMan produces less noisy and more focused explanations than Chefer et al. In each of the three examples, the target classes shown above and below the image are explained separately. (Best viewed in color.)
  • Figure 5: AtMan scales efficiently. Performance comparison of the XAI methods AtMan and Chefer et al. across model sizes (x-axis), executed on a single GPU with 80 GB of memory. Current gradient-based approaches do not scale; only AtMan can be used on large-scale models. Solid lines show GPU memory consumption in GB (left y-axis), and dashed lines show runtime in seconds (right y-axis). Colors indicate input sequence lengths. As a baseline (green), a plain forward pass with a sequence length of 1024 is measured. The lightning symbol marks configurations that cannot be deployed when memory is capped at 80 GB. (Best viewed in color.)
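The two-step procedure of Figure 2a can be sketched end to end on a toy model. First, the target's cross-entropy is recorded. Then one token at a time is suppressed and the change is tracked. Everything about the model is hypothetical: a single attention head whose queries, keys, and values are the raw embeddings, a readout matrix mapping the last position to vocabulary logits, and f = 0.9. Relevance is taken here as the increase in the target's cross-entropy when a token is suppressed. The names `target_loss` and `atman_relevance` are illustrative.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def target_loss(embeddings, readout, target, modifiers):
    """Cross-entropy of `target` for a hypothetical one-layer, one-head toy model."""
    # The next-token prediction is read from the last position. That query may
    # attend to every token, so no causal mask is needed for this single row.
    query = embeddings[-1]
    # Pre-softmax scores of the last query, each scaled by its key's modifier.
    scores = [dot(query, key) * m for key, m in zip(embeddings, modifiers)]
    weights = softmax(scores)
    dim = len(query)
    context = [sum(w * v[j] for w, v in zip(weights, embeddings)) for j in range(dim)]
    logits = [dot(row, context) for row in readout]
    return -math.log(softmax(logits)[target])

def atman_relevance(embeddings, readout, target, factor=0.9):
    """Relevance of token i = loss increase when token i is scaled by 1 - factor."""
    n = len(embeddings)
    base = target_loss(embeddings, readout, target, [1.0] * n)  # step (1)
    relevance = []
    for i in range(n):  # step (2): one independent forward pass per suppressed token
        mods = [1.0 - factor if k == i else 1.0 for k in range(n)]
        relevance.append(target_loss(embeddings, readout, target, mods) - base)
    return relevance

# Token 0 carries the feature that the readout maps to target class 0.
tokens = [[2.0, 0.0], [0.0, 0.5], [1.0, 1.0]]
readout = [[1.0, 0.0], [0.0, 1.0]]
rel = atman_relevance(tokens, readout, target=0)
print([round(r, 3) for r in rel])
```

Token 0 receives the largest (and only positive) relevance, because suppressing it raises the loss of the target class. Suppressing the other tokens slightly lowers that loss.

Each suppressed pass depends only on the original input, so in a real model the n passes can run as one batch. This matches the parallelizable search described in the abstract. No backward pass is involved, so no activations have to be kept for gradients.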