Contextual Feature Selection with Conditional Stochastic Gates
Ram Dyuthi Sristi, Ofir Lindenbaum, Shira Lifshitz, Maria Lavzin, Jackie Schiller, Gal Mishne, Hadas Benisty
TL;DR
The paper addresses context-dependent feature relevance by introducing Conditional STG (c-STG), a framework in which a hypernetwork maps a context z to per-feature gate parameters, and a differentiable Gaussian relaxation of the gates enables gradient-based learning of context-conditioned feature subsets jointly with a predictor. It proves the c-STG formulation equivalent to the original NP-hard contextual feature selection problem, shows that c-STG achieves lower risk than population-level STG, and extends the method to a weighted variant that further scales the selected features by a context-dependent weight vector. Empirically, c-STG and weighted c-STG outperform several baselines on synthetic XOR tasks, rotated MNIST, and real-world data (healthcare, housing, neuroscience), delivering accurate predictions with sparser, context-specific explanations. The work advances context-aware interpretability and generalization in high-dimensional prediction problems, while noting the need for careful hyperparameter tuning and attention to societal impacts.
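A minimal NumPy sketch of the c-STG forward pass summarized above, assuming the standard STG relaxation s_d = clip(mu_d + eps_d, 0, 1) with eps_d ~ N(0, sigma^2) and the gate-count penalty sum_d Phi(mu_d / sigma). Here mu(z) is produced per sample by a hypernetwork. The `HyperNet` class, its two-layer architecture, the 0.5 bias initialization, sigma = 0.5, and all helper names are hypothetical choices for illustration, not the authors' implementation. For the weighted variant, the caller supplies the weight vector w(z); how it is computed from z is left open. The sketch only computes forward values; training the predictor and hypernetwork would require an autodiff framework such as PyTorch or JAX.

```python
import math

import numpy as np


def phi(x):
    """Standard normal CDF Phi, applied elementwise via math.erf."""
    return 0.5 * (1.0 + np.vectorize(math.erf)(np.asarray(x) / math.sqrt(2.0)))


class HyperNet:
    """Hypothetical two-layer MLP mapping context z (n, k) to gate means mu(z) (n, d)."""

    def __init__(self, k, d, hidden=16, seed=0):
        r = np.random.default_rng(seed)
        self.W1 = r.normal(scale=0.5, size=(k, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = r.normal(scale=0.5, size=(hidden, d))
        # Assumption: start every gate half open (mu = 0.5 before context effects).
        self.b2 = np.full(d, 0.5)

    def __call__(self, z):
        h = np.tanh(z @ self.W1 + self.b1)
        return h @ self.W2 + self.b2


def stochastic_gates(mu, sigma=0.5, train=True, rng=None):
    """Relaxed Bernoulli gates: s = clip(mu + eps, 0, 1) with eps ~ N(0, sigma^2).

    At evaluation time the noise is dropped (eps = 0), so the gates are deterministic.
    """
    if train:
        if rng is None:
            rng = np.random.default_rng()
        eps = rng.normal(scale=sigma, size=mu.shape)
    else:
        eps = 0.0
    return np.clip(mu + eps, 0.0, 1.0)


def gate_regularizer(mu, sigma=0.5):
    """Expected number of open gates, sum_d P(s_d > 0) = sum_d Phi(mu_d / sigma),
    averaged over the samples (i.e. over contexts) in the batch."""
    return phi(mu / sigma).sum(axis=1).mean()


def gated_input(x, s, w=None):
    """c-STG input x * s(z); the weighted variant also multiplies by w(z)."""
    out = x * s
    return out if w is None else out * w


# Usage on random data: n samples, d features, k context variables.
rng = np.random.default_rng(0)
n, d, k = 4, 6, 2
x = rng.normal(size=(n, d))
z = rng.normal(size=(n, k))

hyper = HyperNet(k, d)
mu = hyper(z)                                        # per-sample gate means mu(z)
s_train = stochastic_gates(mu, train=True, rng=rng)  # noisy gates for a training step
s_eval = stochastic_gates(mu, train=False)           # deterministic gates at test time
x_tilde = gated_input(x, s_eval)                     # context-specific feature subset
reg = gate_regularizer(mu)                           # sparsity penalty, scaled by lambda in the loss
```

During training, the penalty would be added to the task loss as `task_loss + lam * reg`, with the gate noise resampled at each step. At evaluation the noise is dropped, so feature d is selected for context z exactly when clip(mu_d(z), 0, 1) > 0.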
Abstract
Feature selection is a crucial tool in machine learning and is widely applied across scientific disciplines. Traditional supervised methods generally identify a single, universal set of informative features for the entire population. However, feature relevance often varies with context, while the context itself may not directly affect the outcome variable. Here, we propose a novel architecture for contextual feature selection in which the subset of selected features is conditioned on the value of context variables. Our approach, Conditional Stochastic Gates (c-STG), models feature importance using conditional Bernoulli variables whose parameters are predicted from the contextual variables. To learn these context-dependent gates jointly with a prediction model, we introduce a hypernetwork that maps context variables to feature selection parameters. We further present a theoretical analysis of our model, indicating that it can improve performance and flexibility over population-level methods in complex feature selection settings. Finally, we conduct an extensive benchmark on simulated and real-world datasets across multiple domains, demonstrating that c-STG can improve feature selection while enhancing prediction accuracy and interpretability.
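As a hedged illustration, the contextual objective can be written in the notation of the original STG relaxation. The symbols below are our own labels, not necessarily the paper's: $h_\phi$ is the hypernetwork, $f_\theta$ the predictor, $\mathcal{L}$ the task loss, $\lambda$ the sparsity weight, $\sigma$ the gate noise scale, and $D$ the number of features. The first line is the Bernoulli form with an expected $\ell_0$ penalty; the second and third lines are its Gaussian relaxation. The exact form used in the paper may differ.

```latex
\min_{\theta,\phi}\; \mathbb{E}_{x,y,z}\Big[\, \mathbb{E}_{s \sim \mathrm{Bern}(\pi_\phi(z))}\big[\mathcal{L}(f_\theta(x \odot s), y)\big] + \lambda \sum_{d=1}^{D} \pi_{\phi,d}(z) \Big]

s_d(z) = \max\big(0, \min(1, \mu_d(z) + \epsilon_d)\big), \quad \epsilon_d \sim \mathcal{N}(0, \sigma^2), \quad \mu(z) = h_\phi(z)

\min_{\theta,\phi}\; \mathbb{E}_{x,y,z}\Big[\, \mathbb{E}_{\epsilon}\big[\mathcal{L}(f_\theta(x \odot s(z)), y)\big] + \lambda \sum_{d=1}^{D} \Phi\big(\mu_d(z)/\sigma\big) \Big]
```

The penalty term follows from $P(s_d(z) > 0) = P(\epsilon_d > -\mu_d(z)) = \Phi(\mu_d(z)/\sigma)$, so it is the expected number of open gates for context $z$, and it is differentiable in $\phi$ through $\mu(z)$.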
