Contextual Feature Selection with Conditional Stochastic Gates

Ram Dyuthi Sristi, Ofir Lindenbaum, Shira Lifshitz, Maria Lavzin, Jackie Schiller, Gal Mishne, Hadas Benisty

TL;DR

The paper addresses context-dependent feature relevance by introducing Conditional STG (c-STG), a framework in which a hypernetwork maps context z to per-feature gate parameters and a differentiable Gaussian relaxation enables gradient-based learning of context-conditioned feature subsets together with a predictor. It proves that the probabilistic formulation is equivalent to the original NP-hard contextual feature selection problem, shows that c-STG achieves lower risk than population-level STG, and extends the framework to a weighted variant that further scales the selected features by a context-dependent weight vector. Empirically, c-STG and weighted c-STG outperform several baselines on synthetic XOR tasks, rotated MNIST, and real-world domains (healthcare, housing, neuroscience), delivering accurate predictions with sparser, context-specific explanations. The work advances context-aware interpretability and generalization in high-dimensional prediction problems, while noting the need for careful hyperparameter tuning and consideration of societal impacts.
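The mechanism summarized above can be sketched in a few lines of NumPy. This is a minimal forward-pass illustration under stated assumptions, not the authors' implementation: the hypernetwork is a hypothetical one-hidden-layer tanh MLP with random weights, the sizes `D`, `Z_DIM`, `HIDDEN` and the noise level `SIGMA` are arbitrary, and the gate is assumed to keep the clipped Gaussian relaxation of the original STG, $\tilde{s}_d = \min(1, \max(0, \mu_d(z) + \epsilon_d))$ with $\epsilon_d \sim \mathcal{N}(0, \sigma^2)$, whose expected number of open gates $\sum_d \Phi(\mu_d(z)/\sigma)$ acts as a differentiable sparsity penalty.

```python
import math
import numpy as np

rng = np.random.default_rng(0)

D, Z_DIM, HIDDEN = 5, 2, 8   # hypothetical sizes: features, context dims, hidden units
SIGMA = 0.5                  # std of the injected Gaussian noise (assumed hyperparameter)

# Hypothetical hypernetwork weights (random here; learned jointly with the predictor in practice).
W1 = rng.normal(size=(Z_DIM, HIDDEN))
b1 = np.zeros(HIDDEN)
W2 = rng.normal(size=(HIDDEN, D))
b2 = np.full(D, 0.5)         # offset so gates start roughly half-open (assumption)

def hypernet_mu(z):
    """Map a batch of contexts z (N, Z_DIM) to gate means mu(z) of shape (N, D)."""
    h = np.tanh(z @ W1 + b1)
    return h @ W2 + b2

def relaxed_gates(mu, train=True):
    """Gaussian-relaxed Bernoulli gates: clip(mu + eps, 0, 1), eps ~ N(0, SIGMA^2)."""
    eps = rng.normal(scale=SIGMA, size=mu.shape) if train else 0.0
    return np.clip(mu + eps, 0.0, 1.0)

def expected_open_gates(mu):
    """Per-sample penalty: sum_d P(gate_d > 0) = sum_d Phi(mu_d / SIGMA)."""
    phi = np.vectorize(lambda t: 0.5 * (1.0 + math.erf(t / math.sqrt(2.0))))
    return phi(mu / SIGMA).sum(axis=1)

z = rng.normal(size=(4, Z_DIM))   # four contexts
x = rng.normal(size=(4, D))       # four samples
mu = hypernet_mu(z)
s = relaxed_gates(mu)
x_gated = x * s                   # context-dependent feature subset passed to the predictor
reg = expected_open_gates(mu)     # would be scaled by lambda and added to the prediction loss
```

In training, `reg` would be averaged over the batch, weighted by a regularization coefficient and added to the prediction loss, with gradients reaching the hypernetwork through `mu`; NumPy has no automatic differentiation, so only the forward pass is shown.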

Abstract

Feature selection is a crucial tool in machine learning and is widely applied across various scientific disciplines. Traditional supervised methods generally identify a universal set of informative features for the entire population. However, feature relevance often varies with context, while the context itself may not directly affect the outcome variable. Here, we propose a novel architecture for contextual feature selection where the subset of selected features is conditioned on the value of context variables. Our new approach, Conditional Stochastic Gates (c-STG), models the importance of features using conditional Bernoulli variables whose parameters are predicted based on contextual variables. We introduce a hypernetwork that maps context variables to feature selection parameters to learn the context-dependent gates along with a prediction model. We further present a theoretical analysis of our model, indicating that it can improve performance and flexibility over population-level methods in complex feature selection settings. Finally, we conduct an extensive benchmark using simulated and real-world datasets across multiple domains demonstrating that c-STG can lead to improved feature selection capabilities while enhancing prediction accuracy and interpretability.
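A hypothetical toy example makes the distinction between population-level and contextual selection concrete. The outcome equals feature 0 in context 0 and feature 1 in context 1, while the context never enters the outcome directly; under a budget of one feature per sample, any single population-level choice leaves substantial error, whereas choosing one feature per context fits exactly. This setup is chosen only for illustration and is not one of the paper's benchmarks; plain least squares stands in for the learned predictor.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 2000
x = rng.normal(size=(n, 2))        # two candidate features
z = rng.integers(0, 2, size=n)     # binary context
# Hypothetical toy: the informative feature switches with the context,
# but z itself does not enter the outcome.
y = np.where(z == 0, x[:, 0], x[:, 1])

def mse_using(feature_idx, mask):
    """Least-squares fit (no intercept) of y on one feature, restricted to rows in mask."""
    xi, yi = x[mask, feature_idx], y[mask]
    coef = (xi @ yi) / (xi @ xi)
    return np.mean((yi - coef * xi) ** 2)

everyone = np.ones(n, dtype=bool)
# Population-level selection: the same single feature for every sample.
population_mse = min(mse_using(d, everyone) for d in range(2))
# Contextual selection: the best single feature chosen separately in each context,
# with per-context errors weighted by the context frequency.
contextual_mse = sum(
    (z == c).mean() * min(mse_using(d, z == c) for d in range(2)) for c in range(2)
)
print(f"population: {population_mse:.3f}, contextual: {contextual_mse:.3f}")
```

The population-level error should come out at roughly 0.75, since half of the samples are explained by the wrong feature, while the contextual error is zero up to floating-point rounding.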

Paper Structure (25 sections, 10 theorems, 23 equations, 12 figures, 3 tables, 1 algorithm)

Key Result

Theorem 1

Let $s^*(\mathbf{z})$ and $s'^*(\mathbf{z})$ denote the optimal feature selection functions for the contextual feature selection problem in Eq. eq:optim and for its probabilistic formulation in Eq. eq:bern_risk, respectively. Then $s^*(\mathbf{z}) = s'^*(\mathbf{z})$.
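The intuition behind Theorem 1 can be checked numerically for a single fixed context: with independent Bernoulli gates, the expected regularized risk is a convex combination of the risks of deterministic subsets, so it can never fall below the best deterministic subset and reaches it when every gate probability is 0 or 1. The sketch below assumes an arbitrary made-up risk table over the subsets of `D = 3` features and a hypothetical $\ell_0$ weight `LAM`; it illustrates why the relaxation is tight and is not the paper's proof.

```python
import itertools
import numpy as np

rng = np.random.default_rng(2)
D, LAM = 3, 0.1                    # hypothetical number of features and l0 weight
subsets = list(itertools.product([0, 1], repeat=D))
# Hypothetical regularized risk of each deterministic subset for one fixed context z.
risk = {s: rng.uniform() + LAM * sum(s) for s in subsets}

def expected_risk(pi):
    """Expected risk when each gate s_d ~ Bernoulli(pi_d) independently."""
    total = 0.0
    for s in subsets:
        prob = np.prod([p if sd else 1.0 - p for p, sd in zip(pi, s)])
        total += prob * risk[s]
    return total

best_deterministic = min(risk.values())
grid = np.linspace(0.0, 1.0, 11)   # includes the vertices 0 and 1
best_relaxed = min(expected_risk(pi) for pi in itertools.product(grid, repeat=D))
```

Interior probabilities only mix subset risks, so `best_relaxed` is attained at a vertex of the grid and coincides with `best_deterministic`.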

Figures (12)

  • Figure 1: Illustration with "rotating MNIST". We perform a binary classification between rotated versions of the digits 4 (top row) and 9 (bottom row) from the MNIST dataset. We compare the features selected by the global STG model (yamada2020feature) (A) with those selected by our proposed c-STG (B), which selects features conditioned on the rotation angle. Each image depicts the mean pixel values across all rotated images (A) or across all images at a given rotation angle (B). Red dots indicate the features selected using STG (A) or c-STG (B) for each rotation angle. c-STG learns to change which features it deems most informative as the context (rotation angle) changes.
  • Figure 2: Contextual feature selection framework. Contextual variables $z$ (purple) are fed into the hypernetwork, which outputs the gate parameters $\mu(z)$. These are combined with the noise $\epsilon$ to determine whether each gate $\tilde{s}_d$ (yellow) is open or closed for each feature $x_d$ (blue). For weighted c-STG, the hypernetwork also outputs a weight vector (green) indicating the importance of the selected explanatory features. The selected and weighted features are fed into the prediction model, so the predictor receives only the features selected for the given context, scaled by their learned importance.
  • Figure 3: XOR2. (A) Ground truth feature significance as a function of context, $z$. Feature gates for c-STG (B) and weighted c-STG (C).
  • Figure 4: Heart disease. Feature selection gates $\mu(z)$ for each input feature as a function of context (age and gender; left: females, center: males), produced by weighted c-STG. Differences in the gate values between males and females indicate gender-specific informative features as a function of age.
  • Figure 5: Housing. Geographic significance of nine housing features according to weighted c-STG analysis ($w(z)\times \widetilde{s}(z)$). The red to blue gradient denotes positive to negative impact, respectively, with gray as neutral. Weighted significant features (A) and unselected features (B).
  • ...and 7 more figures
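For the weighted variant, the captions above describe a hypernetwork that outputs both gate parameters and a context-dependent weight vector, with $w(z)\times\tilde{s}(z)$ used as the feature significance map. A minimal NumPy sketch, assuming a hypothetical shared tanh trunk with two linear heads, random weights, a 0.5 offset on the gate means, and deterministic gates $\min(1, \max(0, \mu(z)))$ at evaluation time (noise is drawn only when `train=True`):

```python
import numpy as np

rng = np.random.default_rng(3)
D, Z_DIM, HIDDEN = 4, 2, 6   # hypothetical sizes

# Hypothetical shared trunk with two heads: gate means mu(z) and weights w(z).
W_h = rng.normal(size=(Z_DIM, HIDDEN))
W_mu = rng.normal(size=(HIDDEN, D))
W_w = rng.normal(size=(HIDDEN, D))

def weighted_cstg_inputs(x, z, sigma=0.5, train=False):
    """Return predictor inputs x * s(z) * w(z) and the significance map w(z) * s(z)."""
    h = np.tanh(z @ W_h)
    mu = h @ W_mu + 0.5                      # gate means (offset is an assumption)
    w = h @ W_w                              # signed, context-dependent feature weights
    eps = rng.normal(scale=sigma, size=mu.shape) if train else 0.0
    s = np.clip(mu + eps, 0.0, 1.0)          # relaxed gates; deterministic when train=False
    significance = w * s                     # exactly zero wherever the gate is closed
    return x * significance, significance

z = rng.normal(size=(3, Z_DIM))
x = rng.normal(size=(3, D))
inputs, significance = weighted_cstg_inputs(x, z)
```

Because a closed gate is exactly zero, the significance of an unselected feature is zero regardless of its weight, and the sign of any nonzero entry comes from $w(z)$.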

Theorems & Definitions (10)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • Theorem 5
  • Theorem 5
  • Theorem 5
  • Theorem 5
  • Theorem 5
  • Theorem 5