
GFSNetwork: Differentiable Feature Selection via Gumbel-Sigmoid Relaxation

Witold Wydmański, Marek Śmieja

TL;DR

GFSNetwork addresses feature selection for high-dimensional data by learning a differentiable feature mask via temperature-controlled Gumbel-Sigmoid (GS) sampling. The model consists of a masking network, which produces the GS-based mask, and a task network, which learns from the masked inputs. Training optimizes $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \mathcal{L}_{\text{select}}$, where $\lambda$ weights the selection term, while the temperature $\tau$ is annealed to promote sparsity without sacrificing performance. The method achieves competitive accuracy with far fewer features on classification, regression, and metagenomic benchmarks. It also yields interpretable feature subsets, as illustrated by MNIST visualizations, and incurs near-constant computational overhead as the number of features grows. Its main limitation concerns engineered second-order feature interactions, which points to future work on capturing feature interdependencies while preserving scalability.
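The mask-sampling step described above can be illustrated with the standard binary Gumbel-Sigmoid (relaxed Bernoulli) trick. This is a minimal sketch under stated assumptions, not the GFSNetwork implementation. The per-feature logits stand in for the masking network's output, and the exponential temperature schedule, its endpoints, and all function names are hypothetical. Each feature's logit is perturbed by the difference of two Gumbel samples (a logistic sample) and divided by $\tau$ before the sigmoid. Lowering $\tau$ therefore pushes mask entries toward 0 or 1.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_gumbel(rng: random.Random) -> float:
    """One Gumbel(0, 1) draw via the inverse CDF: -log(-log(U))."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def gumbel_sigmoid_mask(logits, tau, rng):
    """Relaxed Bernoulli mask, one entry per feature: sigmoid((logit + g1 - g2) / tau)."""
    return [
        sigmoid((logit + sample_gumbel(rng) - sample_gumbel(rng)) / tau)
        for logit in logits
    ]


def annealed_tau(step, tau_start=1.0, tau_end=0.1, total_steps=1000):
    """Hypothetical exponential schedule from tau_start down to tau_end."""
    frac = min(step / total_steps, 1.0)
    return tau_start * (tau_end / tau_start) ** frac


def apply_mask(x, mask):
    """Element-wise product: the task network only sees the masked inputs."""
    return [xi * mi for xi, mi in zip(x, mask)]


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [4.0, -4.0, 0.0]  # hypothetical per-feature logits
    x = [1.5, 2.0, -0.7]       # one hypothetical input row
    for step in (0, 500, 1000):
        tau = annealed_tau(step)
        mask = gumbel_sigmoid_mask(logits, tau, rng)
        print(f"tau={tau:.3f}", [round(m, 2) for m in mask], apply_mask(x, mask))
```

At the start of training ($\tau = 1$) the mask is soft and noisy, so gradients reach every feature. Near the end ($\tau = 0.1$) the entries are close to binary, so the task network effectively sees a hard subset of features.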

Abstract

Feature selection in deep learning remains a critical challenge, particularly for high-dimensional tabular data where interpretability and computational efficiency are paramount. We present GFSNetwork, a novel neural architecture that performs differentiable feature selection through temperature-controlled Gumbel-Sigmoid sampling. Unlike traditional methods, which require the user to specify the number of features to select, GFSNetwork determines this number automatically during end-to-end training. Moreover, GFSNetwork maintains constant computational overhead regardless of the number of input features. We evaluate GFSNetwork on a series of classification and regression benchmarks, where it consistently outperforms recent methods, including DeepLasso and attention maps, as well as traditional feature selectors, while using significantly fewer features. Furthermore, we validate our approach on real-world metagenomic datasets, demonstrating its effectiveness on high-dimensional biological data. In conclusion, our method provides a scalable solution that bridges the gap between the flexibility of neural networks and the interpretability of traditional feature selection. We share our Python implementation of GFSNetwork at https://github.com/wwydmanski/GFSNetwork, as well as a PyPI package (gfs_network).
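Selecting the number of features automatically follows from trading a task loss against a sparsity penalty in $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \mathcal{L}_{\text{select}}$. The minimal sketch below illustrates this trade-off with assumed ingredients. Binary cross-entropy stands in for $\mathcal{L}_{\text{task}}$, and the mean selection probability stands in for $\mathcal{L}_{\text{select}}$. Features are kept when their probability exceeds a hypothetical 0.5 threshold. These are illustrative choices, not necessarily the paper's exact terms.

```python
import math


def bce(p: float, y: int, eps: float = 1e-12) -> float:
    """Binary cross-entropy for one prediction (stand-in for L_task)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def selection_loss(mask_probs):
    """Hypothetical L_select: mean selection probability, which pushes masks toward 0."""
    return sum(mask_probs) / len(mask_probs)


def total_loss(preds, labels, mask_probs, lam):
    """L_total = L_task + lambda * L_select."""
    task = sum(bce(p, y) for p, y in zip(preds, labels)) / len(labels)
    return task + lam * selection_loss(mask_probs)


def selected_features(mask_probs, threshold=0.5):
    """Indices kept at inference; their count is not fixed in advance."""
    return [i for i, p in enumerate(mask_probs) if p > threshold]


if __name__ == "__main__":
    probs = [0.97, 0.03, 0.88, 0.10]  # hypothetical learned selection probabilities
    print(selected_features(probs))    # [0, 2]
    print(round(total_loss([0.9, 0.2], [1, 0], probs, lam=0.1), 4))  # 0.2138
```

Under this formulation, a larger $\lambda$ makes each retained feature more costly. The number of surviving features is then an outcome of training rather than a user-supplied hyperparameter.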


Paper Structure

This paper contains 18 sections, 3 equations, 5 figures, 5 tables, 1 algorithm.

Figures (5)

  • Figure 1: Architecture of GFSNetwork. Our method consists of two parts: a masking network and a task network. The masking network, combined with Gumbel noise, produces a binary mask, which the task network then uses to output either a class label or a real number. In this way, we simultaneously optimize the number of selected features and the performance of the predictor built on them.
  • Figure 2: Feature selection analysis showing the feature space representation (left) and the evolution of selection probabilities (right). The heatmap displays the distribution of features across samples, with color intensity indicating feature values. The evolution plot tracks selection probabilities throughout training. It highlights distinct patterns for selected features (red), whose probabilities are high either from initialization or become high at later stages, and for non-selected features (blue), whose selection probabilities decrease over time.
  • Figure 3: The time requirements of GFSNetwork do not increase substantially as the number of features grows.
  • Figure 4: The average entropy of selected features is significantly higher than the entropy of all features, indicating that GFSNetwork selects features with discriminative potential (left). Moreover, the selected pixels are localized in the central region of the image (right).
  • Figure 5: Analysis of sample features (top-left) from the MNIST dataset shows that the entropy of selected features (F1-F3) is much higher than that of non-selected features (F4, F5), confirming that GFSNetwork selects the most discriminative features.
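Figures 4 and 5 compare the entropy of selected and non-selected features. The sketch below shows one way such a comparison could be computed. It assumes a histogram estimate of Shannon entropy (equal-width bins, 10 by default). The paper's actual estimator and bin count are not stated here, and the toy pixel columns are hypothetical.

```python
import math
from collections import Counter


def feature_entropy(values, n_bins=10):
    """Shannon entropy (bits) of one feature after equal-width binning."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return 0.0  # a constant feature carries no information
    width = (hi - lo) / n_bins
    # Map each value to a bin index; the maximum value falls into the last bin.
    bins = [min(int((v - lo) / width), n_bins - 1) for v in values]
    n = len(values)
    return -sum((c / n) * math.log2(c / n) for c in Counter(bins).values())


def mean_entropy(columns, indices=None):
    """Average entropy over the given feature columns (all columns if indices is None)."""
    idx = range(len(columns)) if indices is None else indices
    ents = [feature_entropy(columns[i]) for i in idx]
    return sum(ents) / len(ents)


if __name__ == "__main__":
    border = [0.0] * 9 + [0.1]             # hypothetical edge pixel: almost always black
    centre = [i / 10 for i in range(10)]   # hypothetical central pixel: varies across images
    columns = [border, centre]
    print(round(feature_entropy(border), 3))  # 0.469
    print(round(feature_entropy(centre), 3))  # 3.322
    print(mean_entropy(columns, [1]) > mean_entropy(columns))  # True
```

A pixel that is nearly constant across images, such as an MNIST border pixel, gets low entropy. This matches the observation that selected pixels concentrate in the image center, where values vary between digits.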