An interpretable prototype parts-based neural network for medical tabular data

Jacek Karolczak, Jerzy Stefanowski

TL;DR

This work proposes a new model for tabular data, specifically tailored to medical records, that requires discretization of diagnostic result norms and employs trainable patching over the features describing a patient to learn meaningful prototypical parts from structured data.

Abstract

The ability to interpret machine learning model decisions is critical in domains such as healthcare, where trust in model predictions is as important as their accuracy. Inspired by the development of prototype parts-based deep neural networks in computer vision, we propose a new model for tabular data, specifically tailored to medical records, that requires discretization of diagnostic result norms. Unlike the original vision models, which rely on spatial structure, our method employs trainable patching over the features describing a patient to learn meaningful prototypical parts from structured data. These parts are represented as binary or discretized feature subsets. This allows the model to express prototypes in human-readable terms, enabling alignment with clinical language and case-based reasoning. The proposed neural network is inherently interpretable: it produces concept-based predictions by comparing the patient's description to learned prototypes in the latent space of the network. In experiments, we demonstrate that the model achieves classification performance competitive with widely used baseline models on medical benchmark datasets while also offering transparency, bridging the gap between predictive performance and interpretability in clinical decision support.
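
The discretization step mentioned in the abstract can be illustrated with a small sketch contrasting hard and fuzzy binning of a single scalar value. Everything specific here is an assumption made for illustration: the function names `hard_bin` and `fuzzy_bin`, the bin edges and centers, and the triangular (linear-interpolation) membership used for the fuzzy variant. The paper may define its fuzzy binning differently.

```python
from bisect import bisect_right


def hard_bin(x, edges):
    """One-hot membership: the bin containing x gets 1.0, all others 0.0.

    `edges` are sorted cut points, so len(edges) + 1 bins are produced.
    """
    k = bisect_right(edges, x)  # index of the bin that contains x
    return [1.0 if i == k else 0.0 for i in range(len(edges) + 1)]


def fuzzy_bin(x, centers):
    """Triangular membership (an assumed form of fuzzy binning).

    x is shared between the two neighbouring bin centers in proportion
    to its distance from each; memberships always sum to 1.
    """
    m = [0.0] * len(centers)
    if x <= centers[0]:
        m[0] = 1.0
    elif x >= centers[-1]:
        m[-1] = 1.0
    else:
        j = bisect_right(centers, x) - 1  # left neighbour of x
        t = (x - centers[j]) / (centers[j + 1] - centers[j])
        m[j], m[j + 1] = 1.0 - t, t
    return m


# Hypothetical "low / normal / high" bins for some lab value.
edges = [70, 100]
centers = [60, 85, 110]

# Two nearly identical values fall into different hard bins ...
print(hard_bin(99, edges), hard_bin(101, edges))
# ... while their fuzzy memberships change only slightly.
print(fuzzy_bin(99, centers), fuzzy_bin(101, centers))
```

With hard binning, values just below and just above a cut point receive entirely different encodings, whereas the fuzzy variant changes smoothly across the boundary.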

Paper Structure

This paper contains 22 sections, 7 equations, 2 figures, and 6 tables.

Figures (2)

  • Figure 1: Comparison between fuzzy and hard binning for a scalar input value.
  • Figure 2: Overview of the Model for Explainable Diagnosis using Interpretable Concepts (MEDIC) architecture. Input features describing the patient are first discretized into binary features. These are projected, using the Hadamard product ($\odot$) with sparse patching masks, into a set of $p$ interpretable input parts. The parts are encoded into an embedding space and compared to $n$ learnable prototypes using the $L^2$ distance ($\otimes$); the resulting distances are pooled and passed through a classification head to produce class assignment probabilities.
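
The data flow described in the Figure 2 caption can be sketched as a single forward pass in NumPy. This is a rough illustration under stated assumptions, not the authors' implementation: the linear encoder, the min-pooling over parts, the linear-plus-softmax classification head, the random (untrained) weights, and all names (`forward`, `masks`, `W_enc`, `W_cls`) and dimensions are hypothetical choices made here.

```python
import numpy as np


def forward(x, masks, W_enc, prototypes, W_cls):
    """Forward pass of a prototype-parts model on one binarized patient vector.

    x          : (d,)   binary features after discretization
    masks      : (p, d) sparse patching masks, one per input part
    W_enc      : (d, e) encoder weights (a linear encoder is assumed)
    prototypes : (n, e) prototype vectors in the embedding space
    W_cls      : (n, c) classification head weights (linear head assumed)
    """
    parts = masks * x                      # Hadamard product, shape (p, d)
    z = parts @ W_enc                      # encoded parts, shape (p, e)
    # L2 distance between every encoded part and every prototype: (p, n)
    dists = np.linalg.norm(z[:, None, :] - prototypes[None, :, :], axis=-1)
    pooled = dists.min(axis=0)             # min-pooling over parts (assumed)
    logits = pooled @ W_cls                # shape (c,)
    exp = np.exp(logits - logits.max())    # numerically stable softmax
    return exp / exp.sum()


rng = np.random.default_rng(0)
d, p, n, e, c = 12, 4, 5, 8, 2             # illustrative sizes only

x = rng.integers(0, 2, size=d).astype(float)
masks = (rng.random((p, d)) < 0.3).astype(float)  # trainable in the paper; fixed here
W_enc = rng.normal(size=(d, e))
prototypes = rng.normal(size=(n, e))
W_cls = rng.normal(size=(n, c))

probs = forward(x, masks, W_enc, prototypes, W_cls)
print(probs)  # class assignment probabilities, summing to 1
```

Because each mask selects a small subset of binary features, the part closest to a given prototype can be read back as a set of discretized feature values, which is what makes the comparison interpretable.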