Explainable Neural Networks with Guarantees: A Sparse Estimation Approach

Antoine Ledent, Peng Liu

TL;DR

SparXnet tackles the tension between predictive power and interpretability in neural networks by introducing a sparse, explainable architecture that automatically selects a small set of input features through a softmax-based routing mechanism and learns $K$ one-dimensional Lipschitz transformation functions. The final prediction is a linear combination of these transformed features, providing direct interpretability via feature importances and per-feature effects. The paper proves a generalization bound where the sample complexity scales with the number of selected features $K$ and the Lipschitz constants, with only a logarithmic dependence on the total number of input features $d$ and independence from the number of parameters. Empirical results on synthetic and real datasets show SparXnet achieves competitive or superior performance while maintaining much more interpretable, sparse models, demonstrating practical potential for high-stakes domains like healthcare and finance.

Abstract

Balancing predictive power and interpretability has long been a challenging research area, particularly in powerful yet complex models like neural networks, where nonlinearity obstructs direct interpretation. This paper introduces a novel approach to constructing an explainable neural network that harmonizes predictiveness and explainability. Our model, termed SparXnet, is designed as a linear combination of a sparse set of jointly learned features, each derived from a different trainable function applied to a single 1-dimensional input feature. Leveraging the ability to learn arbitrarily complex relationships, our neural network architecture enables automatic selection of a sparse set of important features, with the final prediction being a linear combination of rescaled versions of these features. We demonstrate the ability to select significant features while maintaining comparable predictive performance and direct interpretability through extensive experiments on synthetic and real-world datasets. We also provide theoretical analysis of the generalization bounds of our framework, which are favorably linear in the number of selected features and only logarithmic in the number of input features. We further lift any dependence of sample complexity on the number of parameters or the architectural details under very mild conditions. Our work paves the way for further research on sparse and explainable neural networks with guarantees.
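
The architecture described above can be illustrated with a minimal forward-pass sketch in plain Python. It is not the authors' implementation: the helper names (`make_pathway`, `pathway_output`, `predict`), the tanh hidden layer, the hidden width, and the hand-set routing logits are all assumptions made for illustration, and no Lipschitz constraint or training loop is included. Each of the $K$ pathways applies a softmax to its own routing logits, forms a soft-selected scalar input, passes it through a small one-dimensional function, and the pathway outputs are combined linearly.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def make_pathway(d, hidden, rng):
    """Hypothetical parameters for one pathway: routing logits plus a tiny 1-D MLP."""
    return {
        "logits": [rng.gauss(0, 1) for _ in range(d)],
        "w1": [rng.gauss(0, 1) for _ in range(hidden)],
        "b1": [0.0] * hidden,
        "w2": [rng.gauss(0, 1) / math.sqrt(hidden) for _ in range(hidden)],
    }

def pathway_output(p, x):
    """Soft-select one scalar from x, then apply the pathway's 1-D function."""
    routing = softmax(p["logits"])                 # weights over the d input features
    z = sum(r * xi for r, xi in zip(routing, x))   # soft-selected scalar input
    h = [math.tanh(w * z + b) for w, b in zip(p["w1"], p["b1"])]
    return sum(w * hj for w, hj in zip(p["w2"], h))

def predict(pathways, alphas, bias, x):
    """Final prediction: linear combination of the K pathway outputs."""
    return bias + sum(a * pathway_output(p, x) for a, p in zip(alphas, pathways))

rng = random.Random(0)
d, K, hidden = 5, 2, 8
pathways = [make_pathway(d, hidden, rng) for _ in range(K)]
alphas = [1.0] * K

# Saturate the routing by hand (assumption: training would drive it here),
# so pathway 0 reads feature index 1 and pathway 1 reads feature index 3.
pathways[0]["logits"] = [0.0, 50.0, 0.0, 0.0, 0.0]
pathways[1]["logits"] = [0.0, 0.0, 0.0, 50.0, 0.0]

x = [0.3, -1.2, 0.7, 2.0, -0.5]
x_perturbed = [9.0, -1.2, -4.0, 2.0, 6.0]  # only the unselected features change

y = predict(pathways, alphas, 0.0, x)
y_perturbed = predict(pathways, alphas, 0.0, x_perturbed)
```

With saturated routing, `y` and `y_perturbed` agree up to floating-point error, which is the sense in which the prediction depends only on the selected features in this sketch.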

Paper Structure

This paper contains 17 sections, 11 theorems, 34 equations, 6 figures, 2 tables.

Key Result

Theorem 1

Consider the function class $\mathcal{F}$ defined above. Suppose we are given $N$ i.i.d. samples $\{(x^i,y^i)\}_{i=1}^N$ with $y^i\in\mathbb{R}$ for all $i\leq N$ and a loss function $\ell:\mathbb{R}^2\rightarrow \mathbb{R}$ which is bounded by $B$ and has Lipschitz constant at most $\mathcal{L}$. Then, with probability $\geq 1-\delta$ over the draw of the training set, we have:

Figures (6)

  • Figure 1: Schematic overview of the proposed model in the case of two selected features ($K=2$). The neural network has two distinct processing pathways, which are based on a softmax operation applied to the weights of the first hidden layer. The softmax operation serves as a soft "routing" mechanism, allowing the network to learn and distribute the representation of different data characteristics across the two pathways, thus achieving adaptive feature selection upon saturation (as indicated by the two solid lines in the first layer). The two fully-connected mapping functions $f_1(x_1)$ and $f_2(x_3)$ are learned automatically and linearly combined to generate the final prediction.
  • Figure 2: Visualizing the learned model predictions using 2000 noisy observations and three features, including one true feature and two random noise features. Our model can correctly identify the second feature as the true input feature (based on the learned weights of 4.84171210e-07, 9.99999523e-01 and 3.66482689e-09 in the first hidden layer) and recover the original form of the underlying data-generating function.
  • Figure 3: The dynamics between predictive accuracy and recovery rate of true features as the sparsity level varies across different models, including SparXnet, Lasso, Ridge regression, decision tree, and FCN. At high sparsity levels (fewer than four features), SparXnet has a higher test RMSE than Lasso. However, as the number of features increases, SparXnet exhibits enhanced performance, eventually surpassing Ridge regression and decision tree models, and approaching the performance of FCN. SparXnet also demonstrates superior or comparable recovery rates of true features relative to Lasso, except when eight features are selected, where potential overlapping coverage of features chosen by SparXnet slightly reduces its recovery rate.
  • Figure 4: Illustrating the model inference for two credit applicants in the credit risk dataset.
  • Figure 5: Comparison of mean test AUC by the number of pathways across datasets. The analysis highlights the variability in performance sensitivity to the number of pathways, serving as a guide for optimal feature selection in practical applications.
  • ...and 1 more figure
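
The first-layer weights quoted in the Figure 2 caption can be read directly as a feature selection. The sketch below uses those three values; the helper name `selected_feature` and the 0.99 saturation threshold are assumptions for illustration and do not come from the paper.

```python
# Softmax weights reported in the Figure 2 caption (first hidden layer).
weights = [4.84171210e-07, 9.99999523e-01, 3.66482689e-09]

def selected_feature(weights, threshold=0.99):
    """Return the 0-based index of the dominant weight, or None if not saturated.

    The threshold is a hypothetical cutoff for calling a pathway saturated.
    """
    idx = max(range(len(weights)), key=weights.__getitem__)
    return idx if weights[idx] >= threshold else None

selected = selected_feature(weights)  # 0-based index 1, i.e. the second feature
```

Here the dominant weight sits on the second feature, matching the caption's statement that the model identifies the second feature as the true input.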

Theorems & Definitions (17)

  • Theorem 1
  • Corollary 2
  • Theorem 3: Covering number of Lipschitz function balls (see von2004distance, Theorem 17, p. 684; see also Tikhomirov1993)
  • Proposition 4
  • proof
  • Proposition 5: cf. bookhighprob, bartlet98, pisier, Ledent_21_Norm-based
  • Proposition 6
  • Proposition 7
  • proof
  • Corollary 8
  • ...and 7 more