TDHook: A Lightweight Framework for Interpretability

Yoann Poupart

TL;DR

TDHook introduces a lightweight, generic interpretability framework for PyTorch designed to handle complex, multi-input/multi-output models across CV, NLP, and DRL. Built on tensordict with a HookedModel and a flexible get-set API, it provides 25+ ready-to-use methods spanning attribution, latent manipulation, and weights-based analyses, while emphasizing composability and minimal dependencies. Benchmarking against existing frameworks shows competitive performance and a smaller footprint, enabling scalable, end-to-end interpretability pipelines; concrete use cases illustrate multi-step pipelines in CV, NLP, and DRL. The work advances practical interpretability by offering a TensorDict-centric, extensible platform that lowers the barrier to building composed explainability workflows in diverse domains.

Abstract

Interpretability of Deep Neural Networks (DNNs) is a growing field driven by the study of vision and language models. Yet some use cases, like image captioning, and some domains, like Deep Reinforcement Learning (DRL), require complex models with multiple inputs and outputs, or rely on composed, separate networks. As a consequence, they rarely fit natively into the API of popular interpretability frameworks. We thus present TDHook, an open-source, lightweight, generic interpretability framework based on tensordict and applicable to any torch model. It focuses on handling complex composed models, whether they are trained for Computer Vision, Natural Language Processing, Reinforcement Learning or any other domain. The library features ready-to-use methods for attribution and probing, together with a flexible get-set API for interventions, and aims to bridge the gap between these method classes to make modern interpretability pipelines more accessible. TDHook is designed with minimal dependencies, requiring roughly half as much disk space as transformer_lens, and, in our controlled benchmark, achieves up to a ×2 speed-up over captum when running integrated gradients for multi-target pipelines on both CPU and GPU. Finally, to demonstrate its value, we showcase concrete use cases of the library with composed interpretability pipelines in Computer Vision (CV) and Natural Language Processing (NLP), as well as with complex models in DRL.
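The speed comparison above concerns integrated gradients (IG) computed for several targets. As a reminder of what that method computes per target (the input-baseline difference multiplied by the gradient averaged along the straight path from baseline to input), here is a dependency-free sketch. It is not TDHook's or captum's implementation: gradients come from hand-written functions instead of autograd, and `integrated_gradients` and the two toy targets are hypothetical.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]


def integrated_gradients(
    grad_fn: Callable[[Vector], Vector],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 64,
) -> Vector:
    """Approximate integrated gradients with a midpoint Riemann sum.

    grad_fn(z) must return the gradient of one scalar target F at point z.
    """
    avg_grad = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th interval on [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / steps
    # Scale the path-averaged gradient by the input-baseline difference.
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Two targets of a toy model, each with a hand-written gradient.
targets: Dict[str, Callable[[Vector], Vector]] = {
    "out0": lambda z: [2 * z[0], 3.0],  # F0(z) = z0**2 + 3*z1
    "out1": lambda z: [z[1], z[0]],     # F1(z) = z0*z1
}
x, baseline = [1.0, 2.0], [0.0, 0.0]
attributions = {name: integrated_gradients(g, x, baseline) for name, g in targets.items()}
```

Up to rounding, each target's attributions sum to F(x) - F(baseline), i.e. 7 for `out0` and 2 for `out1`; this is the completeness property of IG, and here the midpoint rule is exact because both gradients are linear along the path.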

Paper Structure

This paper contains 49 sections, 19 figures, 2 tables.

Figures (19)

  • Figure 1: Schematic view of the TDHook framework architecture. The target to interpret is a torch module or a TensorDictModule wrapped in a HookedModel. This object exposes different APIs, like the get-set API described in the get-set API section, while its forward pass can be modified to fit the needs of different interpretability techniques (see an example in the ready-to-use methods section). The artifacts produced and consumed by the models and methods are materialised as TensorDict objects. More details about the design principles are given in the design principles section, while interactions between the main components of the framework are described in the code appendix.
  • Figure 2: Schematic view of the code execution for the Saliency method in TDHook. (1) First, we define a context factory that turns a Module or TensorDictModule into a HookedModel instance. (2) Then, we run this modified model inside the context with TensorDict inputs. (3) Finally, we retrieve the attributions from the output TensorDict.
  • Figure 3: Schematic view of the code execution for an intervention in TDHook. (1) After explicitly defining a HookedModel instance, we enter a run context with the given inputs. (2-4) We query the run instance to define our intervention scheme, which registers the required hooks on the model. The get-set API returns cache proxies that can be resolved once the model has run. (5-7) When the context exits, the model is executed, which triggers the registered hooks and populates the cache. (8) Finally, we can retrieve any intermediate state by resolving the previously defined proxies.
  • Figure 4: Relative performance of interpretability frameworks on their "Get Started" task compared against the tdhook implementation (lower is better). The colours follow a logarithmic scale centred on tdhook's relative performance. More technical details and results are given in the benchmarking appendix, and the individual task results are provided in the supplementary material.
  • Figure 5: Bundle size comparison across different interpretability frameworks, measured as size on disk (a) and inode count (b). For each measurement, we compare against a baseline in which only torch is installed; this baseline footprint could be reduced further by choosing a lighter build of torch, e.g. torch-cpu.
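Figure 1 relies on modules that read and write named entries of a TensorDict, which is what lets multi-input/multi-output and composed networks share one interface. The sketch below mimics that idea with a plain dict. `KeyedModule` and `KeyedSequential` are hypothetical names chosen for illustration, not TDHook or tensordict classes, and the toy actor-critic is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence

Store = Dict[str, Any]  # plain-dict stand-in for a TensorDict


@dataclass
class KeyedModule:
    """Hypothetical stand-in for a keyed module: reads in_keys, writes out_keys."""
    fn: Callable[..., Any]
    in_keys: Sequence[str]
    out_keys: Sequence[str]

    def __call__(self, store: Store) -> Store:
        result = self.fn(*(store[key] for key in self.in_keys))
        # A single output key receives the raw result; several keys expect a tuple.
        values = (result,) if len(self.out_keys) == 1 else result
        for key, value in zip(self.out_keys, values):
            store[key] = value  # written in place into the shared store
        return store


@dataclass
class KeyedSequential:
    """Runs keyed modules in order over one shared store."""
    modules: List[KeyedModule]

    def __call__(self, store: Store) -> Store:
        for module in self.modules:
            store = module(store)
        return store


# A toy actor-critic: a shared encoder feeds an actor and a two-output critic.
encoder = KeyedModule(lambda obs: [2 * o for o in obs], ["observation"], ["embedding"])
actor = KeyedModule(lambda emb: sum(emb), ["embedding"], ["action"])
critic = KeyedModule(
    lambda emb, act: (sum(emb) + act, act > 0),
    ["embedding", "action"],
    ["value", "action_is_positive"],
)
agent = KeyedSequential([encoder, actor, critic])
result = agent({"observation": [1, 2, 3]})
```

Because every sub-network in this sketch is addressed by keys rather than positional arguments, an analysis could inspect only the critic's `value` entry without touching the encoder or the actor.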
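The three steps of Figure 2 (wrap the model with a context factory, run it on keyed inputs, read attributions from the keyed output) can be sketched without any dependency. This is an illustration under stated assumptions, not TDHook's API: plain dicts stand in for TensorDicts, central finite differences stand in for an autograd backward pass, and `saliency_context` and the `"<key>_attributions"` output name are hypothetical.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List

Inputs = Dict[str, List[float]]
Model = Callable[[Inputs], Dict[str, float]]


@contextmanager
def saliency_context(
    model: Model, in_key: str, out_key: str, eps: float = 1e-6
) -> Iterator[Callable[[Inputs], Dict[str, object]]]:
    """(1) Wrap `model` so its outputs also carry saliency attributions."""

    def hooked(inputs: Inputs) -> Dict[str, object]:
        outputs: Dict[str, object] = dict(model(inputs))
        x = inputs[in_key]
        grads = []
        for i in range(len(x)):
            up, down = list(x), list(x)
            up[i] += eps
            down[i] -= eps
            # Central finite difference stands in for an autograd backward pass.
            f_up = model({**inputs, in_key: up})[out_key]
            f_down = model({**inputs, in_key: down})[out_key]
            grads.append((f_up - f_down) / (2 * eps))
        # Saliency: absolute gradient of the target w.r.t. each input feature.
        outputs[f"{in_key}_attributions"] = [abs(g) for g in grads]
        return outputs

    yield hooked
    # Nothing to undo here; a hook-based version would remove its hooks at this point.


def toy_model(inputs: Inputs) -> Dict[str, float]:
    x0, x1 = inputs["x"]
    return {"score": 3 * x0 - 2 * x1 + x0 * x1}


# (2) Run the wrapped model inside the context, (3) read attributions from the output.
with saliency_context(toy_model, in_key="x", out_key="score") as hooked_model:
    out = hooked_model({"x": [1.0, 2.0]})
saliency = out["x_attributions"]
```

At x = [1, 2] the analytic gradient of the toy score is [3 + x1, -2 + x0] = [5, -1], so the saliency is approximately [5, 1].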
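The deferred pattern of Figure 3 (register get/set requests inside a run context, execute the model on exit, then resolve proxies) is sketched below in plain Python. It assumes a toy model given as an ordered list of named layers and replaces torch forward hooks with checks inside a loop; `Run`, `Proxy`, `get` and `set` are hypothetical here, and TDHook's actual names and signatures may differ.

```python
from typing import Any, Callable, Dict, List, Set, Tuple

Layer = Tuple[str, Callable[[Any], Any]]


class Proxy:
    """Placeholder for a cached value that only exists after the model has run."""

    def __init__(self, cache: Dict[str, Any], key: str) -> None:
        self._cache = cache
        self._key = key

    def resolve(self) -> Any:
        if self._key not in self._cache:
            raise RuntimeError(f"{self._key!r} is not cached yet; exit the run context first")
        return self._cache[self._key]


class Run:
    """Collects get/set requests, then executes the model when the context exits."""

    def __init__(self, layers: List[Layer], inputs: Any) -> None:
        self.layers = layers
        self.inputs = inputs
        self.cache: Dict[str, Any] = {}
        self.output: Any = None
        self._gets: Set[str] = set()
        self._sets: Dict[str, Any] = {}

    def __enter__(self) -> "Run":
        return self

    def get(self, name: str) -> Proxy:
        self._gets.add(name)  # register a "read" hook on this layer
        return Proxy(self.cache, name)

    def set(self, name: str, value: Any) -> None:
        self._sets[name] = value  # register an "overwrite" hook on this layer

    def __exit__(self, exc_type, exc, tb) -> bool:
        if exc_type is None:  # execute the model only if the block succeeded
            h = self.inputs
            for name, layer in self.layers:
                h = layer(h)
                if name in self._sets:  # the intervention replaces the layer output
                    h = self._sets[name]
                if name in self._gets:  # cache the (possibly patched) output
                    self.cache[name] = h
            self.output = h
        return False  # never swallow exceptions


layers: List[Layer] = [
    ("double", lambda v: 2 * v),
    ("add_one", lambda v: v + 1),
    ("square", lambda v: v * v),
]
with Run(layers, inputs=3) as run:
    early = run.get("double")  # proxy, unresolved until the context exits
    run.set("add_one", 10)     # patch: ignore the computed 7, use 10 instead
    late = run.get("square")
```

After the block, `early.resolve()` gives 6 and `late.resolve()` gives 100 instead of the unpatched 49; resolving a proxy inside the block raises, since the model has not run yet.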
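Figure 4 colours each framework by its performance relative to tdhook on a logarithmic scale. Assuming that relative performance means a runtime ratio and that the logarithm is base 2 (the figure does not state the base), the mapping can be sketched as follows. The helper names and the timings below are made up for illustration and are not results from the paper.

```python
import math
import statistics
import time
from typing import Callable, Dict


def median_runtime(task: Callable[[], object], repeats: int = 5) -> float:
    """Median wall-clock time of `task` over several runs, in seconds."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        task()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def relative_log_scores(timings: Dict[str, float], reference: str) -> Dict[str, float]:
    """log2(t / t_reference): 0 for the reference, > 0 if slower, < 0 if faster."""
    ref = timings[reference]
    return {name: math.log2(t / ref) for name, t in timings.items()}


# Made-up timings in seconds, for illustration only.
timings = {"tdhook": 2.0, "framework_a": 4.0, "framework_b": 1.0}
scores = relative_log_scores(timings, reference="tdhook")
```

With a log scale centred on the reference, a framework twice as slow and one twice as fast sit symmetrically at +1 and -1, which a linear ratio scale would not show.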
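Figure 5 reports two footprint measures. One plausible way to obtain them for an installed package directory is sketched below with the standard library only. The exact measurement procedure is not described here, so `bundle_footprint` and its counting conventions (apparent file sizes, hard links counted once, symlinks not followed) are assumptions.

```python
import os
import stat
import tempfile
from typing import Set, Tuple


def bundle_footprint(root: str) -> Tuple[int, int]:
    """Return (bytes, inodes) for the tree under `root`.

    Bytes sum the apparent sizes (st_size) of regular files, which can differ from
    the allocated blocks reported by `du`. Every file, directory and symlink counts
    as one inode; hard links are counted once and symlinks are not followed.
    """
    seen: Set[Tuple[int, int]] = set()
    total_bytes = 0

    def visit(path: str) -> None:
        nonlocal total_bytes
        st = os.lstat(path)
        key = (st.st_dev, st.st_ino)
        if key in seen:
            return
        seen.add(key)
        if stat.S_ISREG(st.st_mode):
            total_bytes += st.st_size

    visit(root)
    for dirpath, dirnames, filenames in os.walk(root):
        for name in dirnames + filenames:
            visit(os.path.join(dirpath, name))
    return total_bytes, len(seen)


# Demo on a throwaway tree: root/, a.txt (5 bytes), sub/, sub/b.txt (3 bytes).
with tempfile.TemporaryDirectory() as demo_root:
    with open(os.path.join(demo_root, "a.txt"), "w") as f:
        f.write("hello")
    os.mkdir(os.path.join(demo_root, "sub"))
    with open(os.path.join(demo_root, "sub", "b.txt"), "w") as f:
        f.write("abc")
    demo_bytes, demo_inodes = bundle_footprint(demo_root)
```

A comparison against a torch-only baseline could then measure an environment's site-packages directory with and without the framework installed and take the difference.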