ProtoGate: Prototype-based Neural Networks with Global-to-local Feature Selection for Tabular Biomedical Data

Xiangjian Jiang, Andrei Margeloiu, Nikola Simidjievski, Mateja Jamnik

TL;DR

HDLSS tabular biomedical data pose challenges for feature selection and prediction due to high dimensionality and limited samples. ProtoGate tackles this by integrating a global-to-local feature selector with a non-parametric prototype-based predictor, with the selector and predictor learned disjointly to avoid co-adaptation. The model uses soft global sparsity via an $\ell_1$ penalty $\|\mathbf{W}^{[1]}\|_1$ on the first-layer weights of the gating network and instance-specific local masks derived from Gaussian-perturbed activations, together with a differentiable $K$-NN over a prototype base and a hybrid NeuralSort-QuickSort scheme for efficient, explainable predictions. Empirical results on seven HDLSS real-world datasets and four non-HDLSS datasets show that ProtoGate achieves higher accuracy with fewer selected features, maintains computational efficiency, and provides robust interpretability through prototypical explanations and high-fidelity feature selection. An open-source implementation is available.
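The global-to-local selector can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the ProtoGate implementation: the one-hidden-layer `tanh` gating network, the clipped-Gaussian gate $s_d = \min(1, \max(0, \mu_d + \epsilon_d))$ with $\epsilon_d \sim \mathcal{N}(0, \sigma^2)$ applied only during training, and the helper names `l1_penalty`, `local_mask` and `select` are hypothetical choices made for this example. The $\ell_1$ term would be added to the training loss to push first-layer weights towards zero (soft global selection), while the clipped gate output gives the instance-wise mask (local selection).

```python
import math
import random


def l1_penalty(W1):
    """||W^[1]||_1: sum of absolute first-layer weights (soft global sparsity)."""
    return sum(abs(w) for row in W1 for w in row)


def matvec(M, v):
    """Dense matrix-vector product with M stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def local_mask(x, W1, b1, W2, b2, sigma=0.5, training=True, rng=None):
    """Instance-wise mask s in [0, 1]^D from a hypothetical one-hidden-layer gate.

    Assumption: s_d = clip(mu_d + eps_d, 0, 1) with eps_d ~ N(0, sigma^2)
    during training and eps_d = 0 at inference.
    """
    if rng is None:
        rng = random.Random(0)
    h = [math.tanh(z + b) for z, b in zip(matvec(W1, x), b1)]  # hidden activations
    mu = [z + b for z, b in zip(matvec(W2, h), b2)]            # one value per feature
    mask = []
    for m in mu:
        eps = rng.gauss(0.0, sigma) if training else 0.0       # Gaussian perturbation
        mask.append(min(1.0, max(0.0, m + eps)))               # clip to [0, 1]
    return mask


def select(x, mask):
    """Local feature selection: element-wise product of x and the mask s."""
    return [xi * si for xi, si in zip(x, mask)]


W1 = [[1.0, 0.0]]  # the weight on the second input has already been driven to zero
b1 = [0.0]
W2 = [[2.0], [-2.0]]
b2 = [0.5, 0.5]
x = [1.0, 3.0]
s = local_mask(x, W1, b1, W2, b2, training=False)
print(s, select(x, s), l1_penalty(W1))  # [1.0, 0.0] [1.0, 0.0] 1.0
```

Clipping makes the mask exactly 0 or 1 whenever the perturbed activation leaves $[0, 1]$, which is what lets a soft, trainable gate act as a hard selector at inference.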

Abstract

Tabular biomedical data poses challenges in machine learning because it is often high-dimensional and typically low-sample-size (HDLSS). Previous research has attempted to address these challenges via local feature selection, but existing approaches often fail to achieve optimal performance due to their limited ability to identify globally important features and their susceptibility to the co-adaptation problem. In this paper, we propose ProtoGate, a prototype-based neural model for feature selection on HDLSS data. ProtoGate first selects instance-wise features by adaptively balancing global and local feature selection. Furthermore, ProtoGate employs a non-parametric prototype-based prediction mechanism to tackle the co-adaptation problem, ensuring that the feature selection results and predictions are consistent with underlying data clusters. We conduct comprehensive experiments to evaluate the performance and interpretability of ProtoGate on synthetic and real-world datasets. The results show that ProtoGate generally outperforms state-of-the-art methods in prediction accuracy by a clear margin while providing high-fidelity feature selection and explainable predictions. Code is available at https://github.com/SilenceX12138/ProtoGate.
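The non-parametric prediction step can be illustrated with a plain (hard) $K$-NN over a prototype base. This sketch assumes the base stores masked training samples with their labels and uses Euclidean distance; ProtoGate trains with a differentiable $K$-NN and hybrid sorting instead, so treat the hypothetical `predict_with_prototypes` as an inference-time view only. It has no trainable parameters, and the returned exemplars are the nearest prototypes that act as prototypical explanations.

```python
import math
from collections import Counter


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def predict_with_prototypes(query, base, k=3):
    """Hard K-NN over a prototype base (hypothetical inference-time sketch).

    query: a masked sample, i.e. x multiplied element-wise by s_local
    base:  list of (prototype_vector, label) pairs, assumed to be masked training samples
    Returns the majority label and the K nearest prototypes (the exemplars).
    """
    # sorted() is stable, so prototypes at equal distance keep their base order
    ranked = sorted(base, key=lambda item: euclidean(query, item[0]))
    exemplars = ranked[:k]
    votes = Counter(label for _, label in exemplars)
    # most_common() lists equal counts in first-encountered order, so a tied vote
    # goes to the class that appears first among the nearest prototypes
    label = votes.most_common(1)[0][0]
    return label, exemplars


base = [([0.0, 0.0], "A"), ([0.1, 0.0], "A"),
        ([1.0, 1.0], "B"), ([0.9, 1.1], "B"), ([1.2, 0.9], "B")]
label, exemplars = predict_with_prototypes([0.2, 0.1], base, k=3)
print(label, [lbl for _, lbl in exemplars])  # A ['A', 'A', 'B']
```

Because the predictor has nothing to fit, only the selector receives gradient updates during training, which is the property the abstract relies on to limit co-adaptation.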

Paper Structure (50 sections, 2 theorems, 15 equations, 11 figures, 23 tables)

Key Result

Lemma 4.2

For the relaxed permutation matrix $\widehat{\mathbf{P}}$, assuming the entries of ${\bm{v}}$ are independently drawn from a distribution that is absolutely continuous with respect to the Lebesgue measure on ${\mathbb{R}}$, the following convergence holds almost surely: where $\mathcal{U}$ denotes a discrete uniform distribution. This convergence follows from Theorem 4 in grover2018stochastic.
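The relaxation behind Lemma 4.2 can be checked numerically. The sketch below implements the NeuralSort relaxation from grover2018stochastic, where row $i$ of $\widehat{\mathbf{P}}$ is $\operatorname{softmax}\big[((n+1-2i)\,\mathbf{v} - A_{\mathbf{v}}\mathbb{1})/\tau\big]$ with $A_{\mathbf{v}}[j,k] = |v_j - v_k|$; we assume ProtoGate's Definition 4.1 takes this form, and the function names are hypothetical. As $\tau \to 0^+$, each row approaches the corresponding row of the exact descending-sort permutation matrix, which is the convergence stated in Theorem 4 of grover2018stochastic. Ties occur with probability zero under the continuity assumption; the hard version breaks them by index.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def neuralsort(v, tau):
    """Relaxed permutation matrix P_hat(tau) following NeuralSort.

    Row i (1-indexed) = softmax(((n + 1 - 2i) * v - A_v 1) / tau),
    with A_v[j][k] = |v_j - v_k|. Rows approximate the permutation that
    sorts v in descending order.
    """
    n = len(v)
    row_sums = [sum(abs(vj - vk) for vk in v) for vj in v]  # A_v 1
    P_hat = []
    for i in range(1, n + 1):
        c = n + 1 - 2 * i
        P_hat.append(softmax([(c * vj - bj) / tau for vj, bj in zip(v, row_sums)]))
    return P_hat


def hard_sort_permutation(v):
    """Exact descending-sort permutation matrix (ties broken by index)."""
    order = sorted(range(len(v)), key=lambda j: -v[j])
    return [[1.0 if j == idx else 0.0 for j in range(len(v))] for idx in order]


def max_gap(v, tau):
    """Largest entry-wise deviation between P_hat(tau) and the exact matrix."""
    P_hat, P = neuralsort(v, tau), hard_sort_permutation(v)
    return max(abs(a - b) for ra, rb in zip(P_hat, P) for a, b in zip(ra, rb))


v = [1.0, 3.0, 2.0]
for tau in (1.0, 0.1, 0.01):
    print(tau, max_gap(v, tau))  # the gap shrinks as tau decreases
```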

Figures (11)

  • Figure 1: An overview of the proposed model. ProtoGate introduces a novel disjoint in-model selection method. It balances global and local feature selection, and makes explainable prototypical predictions. In contrast to (a), ProtoGate integrates a trainable feature selector with a non-trainable predictor (i.e., no trainable parameters in the predictor), which allows for disjointly learned feature selector and predictor, thus mitigating the co-adaptation problem. In contrast to (b), ProtoGate makes predictions with the selected features, preserving their in-model explainability.
  • Figure 2: The architecture of ProtoGate. (A) Given a sample ${\bm{x}} \in {\mathbb{R}}^{D}$, the global-to-local feature selection performs soft global feature selection in the first layer of the gating network. The orange dashed lines denote sparsified weights (i.e., reduced to zero) in $\mathbf{W}^{[1]}$ under $\ell_1$-regularisation. The neural network then computes the instance-wise mask $\{s_d\}_{d=1}^{D} \in [0,1]^{D}$ with a thresholding function for local feature selection. (B) The local mask ${\bm{s}}_{\text{local}}$ is applied to the sample for local feature selection by element-wise multiplication. (C) The non-parametric prototype-based prediction then classifies ${\bm{x}} \odot {\bm{s}}_{\text{local}}$ by retrieving the $K$ nearest prototypes in base $\mathcal{B}$ via hybrid sorting. The majority class is used as the predicted label $\hat{y}$, and the exemplars (i.e., the nearest prototypes) provide prototypical explanations.
  • Figure 3: Left: Mean normalised feature selection sparsity vs. mean normalised balanced accuracy. We exclude the outliers (TabNet, L2X and INVASE) due to their suboptimal results or failed convergence. Middle: Median runtime vs. mean normalised balanced accuracy. Right: Median model size vs. mean normalised balanced accuracy. ProtoGate generally achieves higher accuracy with fewer selected features and higher computational efficiency than other local methods.
  • Figure 4: Fidelity evaluation of selected features on three synthetic datasets. Left: Mean normalised F1 score of selected features (F1$_\text{select}$) vs. mean normalised balanced accuracy (ACC$_{\text{pred}}$). Right: Rank difference between F1$_\text{select}$ and ACC$_{\text{pred}}$. A positive value (highlighted in red) indicates low-fidelity feature selection. Short bars are plotted at zero for visual clarity. ProtoGate achieves a competitive trade-off between F1$_\text{select}$ and ACC$_{\text{pred}}$ and consistently non-positive rank differences, showing robustness to co-adaptation.
  • Figure 5: Normalised balanced accuracy (%) of simple models with different feature selectors. ProtoGate, with global-to-local selection, selects highly transferable features that generally improve the classification performance of KNN and SVM.
  • ...and 6 more figures
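Figures 3 and 4 are built on two standard metrics, sketched below. Balanced accuracy is the mean of per-class recalls. F1$_\text{select}$ is taken here, as an assumption, to be the F1 score between a selected feature set and the ground-truth informative features of a synthetic dataset; for instance-wise selectors one would compute it per sample and average. The cross-dataset normalisation and the rank-difference computation in Figure 4 are not reproduced, and the function names are hypothetical.

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recalls over the classes present in y_true."""
    recalls = []
    for c in sorted(set(y_true)):
        idx = [i for i, y in enumerate(y_true) if y == c]
        hits = sum(1 for i in idx if y_pred[i] == c)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


def f1_select(selected, informative):
    """F1 score between selected feature indices and ground-truth informative ones."""
    selected, informative = set(selected), set(informative)
    tp = len(selected & informative)
    if tp == 0:
        return 0.0
    precision = tp / len(selected)
    recall = tp / len(informative)
    return 2 * precision * recall / (precision + recall)


print(balanced_accuracy([0, 0, 0, 1], [0, 0, 0, 0]))  # 0.5
print(round(f1_select([0, 1, 5], [0, 1, 2, 3]), 4))  # 0.5714
```

In the first example plain accuracy would be 0.75, while balanced accuracy is 0.5 because the minority class is never predicted; this is why balanced accuracy is the more informative summary on imbalanced biomedical datasets.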

Theorems & Definitions (4)

  • Definition 4.1
  • Lemma 4.2
  • Theorem 4.3
  • proof