Neural Additive Models: Interpretable Machine Learning with Neural Nets

Rishabh Agarwal, Levi Melnick, Nicholas Frosst, Xuezhou Zhang, Ben Lengerich, Rich Caruana, Geoffrey Hinton

TL;DR

NAMs are a differentiable, additive modeling framework: each input feature is processed by its own neural network, and the prediction is the sum of these univariate shape functions, following the generalized additive model (GAM) form. The resulting models are interpretable (each shape function can be plotted) and competitive in accuracy with state-of-the-art intelligible methods, while also supporting capabilities such as intelligible parameter generation and multitask learning. The paper evaluates NAMs on diverse tabular datasets (e.g., FICO, MIMIC-II, California Housing, Credit Fraud), shows that ExU units help learn jagged shapes, and reports multitask gains and interpretable biases on COMPAS. It also discusses practical considerations, comparisons to GANNs and EBMs, and future directions toward higher-order interactions and broader domain applicability.

Abstract

Deep neural networks (DNNs) are powerful black-box predictors that have achieved impressive performance on a wide variety of tasks. However, their accuracy comes at the cost of intelligibility: it is usually unclear how they make their decisions. This hinders their applicability to high-stakes decision-making domains such as healthcare. We propose Neural Additive Models (NAMs), which combine some of the expressivity of DNNs with the inherent intelligibility of generalized additive models. NAMs learn a linear combination of neural networks that each attend to a single input feature. These networks are trained jointly and can learn arbitrarily complex relationships between their input feature and the output. Our experiments on regression and classification datasets show that NAMs are more accurate than widely used intelligible models such as logistic regression and shallow decision trees. They perform similarly to existing state-of-the-art generalized additive models in accuracy, but are more flexible because they are based on neural nets instead of boosted trees. To demonstrate this, we show how NAMs can be used for multitask learning on synthetic data and on the COMPAS recidivism data due to their composability, and demonstrate that the differentiability of NAMs allows them to train more complex interpretable models for COVID-19.
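
The additive structure described in the abstract can be sketched in a few lines. The following is a minimal, untrained illustration in plain Python, not the authors' implementation: the helper names (`make_feature_net`, `nam_contributions`, `nam_predict_proba`), the single hidden layer of 8 ReLU units, the random initialisation, and the sigmoid link for binary classification are all assumptions made for the example.

```python
import math
import random


def make_feature_net(hidden_units, rng):
    """Hypothetical helper: random one-hidden-layer net mapping a scalar to a scalar."""
    return {
        "w1": [rng.gauss(0.0, 1.0) for _ in range(hidden_units)],
        "b1": [rng.gauss(0.0, 1.0) for _ in range(hidden_units)],
        "w2": [rng.gauss(0.0, 1.0 / hidden_units) for _ in range(hidden_units)],
    }


def feature_net_forward(net, x):
    """Shape function f_i(x_i): ReLU hidden layer followed by a linear readout."""
    hidden = [max(0.0, w * x + b) for w, b in zip(net["w1"], net["b1"])]
    return sum(v * h for v, h in zip(net["w2"], hidden))


def nam_contributions(nets, x):
    """One contribution per feature; net i only ever sees feature i."""
    return [feature_net_forward(net, xi) for net, xi in zip(nets, x)]


def nam_predict_proba(nets, bias, x):
    """Sum the per-feature contributions, add a bias, apply a sigmoid link."""
    logit = bias + sum(nam_contributions(nets, x))
    return 1.0 / (1.0 + math.exp(-logit))


rng = random.Random(0)
nets = [make_feature_net(8, rng) for _ in range(3)]  # three input features
x = [0.5, -1.2, 2.0]
contribs = nam_contributions(nets, x)
p = nam_predict_proba(nets, 0.1, x)
```

Because each net receives exactly one feature, changing `x[0]` can only change `contribs[0]`. This is what makes each learned shape function plottable on its own. In a real NAM the nets would be trained jointly by gradient descent, which this sketch omits.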

Paper Structure

This paper contains 26 sections, 6 equations, 17 figures, 7 tables.

Figures (17)

  • Figure 1: NAM architecture for binary classification. Each input variable is handled by a different neural network. This results in easily interpretable yet highly accurate models.
  • Figure 2: Accurately Fitting the Toy Dataset: Training predictions learned by a single hidden layer neural network with 1024 (a) standard ReLU, and (b) ReLU-$n$ with ExU hidden units trained for 10,000 epochs on the binary classification dataset described in Section \ref{sec:new_activation}. We can see that the ReLU network has learned a fairly smooth function while the ExU network has learned a very jumpy function. We find that a DNN with three hidden layers also learned smooth functions (see Figure \ref{fig:dnn_toy}).
  • Figure 3: Regularizing ExU networks. Output of an ExU feature net trained with dropout = $0.2$ for the age feature in the MIMIC-II dataset \cite{saeed2011multiparameter}. Predictions from individual subnets (as a result of dropping out hidden units) are much more jagged than the average predictions using the entire feature net. Refer to Section \ref{sec:regularization} for an overview of regularization approaches used in this work.
  • Figure 4: ExU vs. standard hidden units. On MIMIC-II, NAMs trained with ExU units learn jumpier graphs than with standard units while achieving a similar AUC ($\approx 0.829$). Ensembling them further improves performance ($\approx 0.830$). Note that white regions in the plots correspond to regions with low data density (typically a few points) and thus we see much higher variance in the learned shape functions. We present a detailed case study on the MIMIC-II dataset in Section \ref{sec:mimic2}.
  • Figure 5: Understanding individual predictions for credit scores. Feature contributions from the learned NAMs for predicting the scores of two applicants in the FICO dataset \cite{fico}. For a given input, each feature net in the NAM acts as a lookup table and returns a contribution term. These contributions are combined in a modular way: they are added up and passed through a link function for prediction. The high-scoring applicant has a long credit history (Average Months on File), which raises their credit score: the longer a person's credit history, the better it is for their credit score. In contrast, the low-scoring applicant used their credit quite frequently (Total Number of Trades) and has a large burden (Net Fraction Installment Burden), thus resulting in a low score.
  • ...and 12 more figures
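
The contrast in the Figure 2 caption between smooth ReLU fits and jumpy ExU fits can be illustrated with a single hidden unit. The sketch below assumes the ExU form $h(x) = f(e^{w}(x - b))$ with $f$ a ReLU capped at $n$ (ReLU-$n$), which is how the paper defines these units as far as we can tell; this page does not state the formula itself. The function names and the values `w = 4.0`, `b = 0.5`, `n = 1.0` are illustrative choices, not settings taken from the paper.

```python
import math


def relu_n(z, n=1.0):
    """ReLU capped at n: clamps z to the interval [0, n]."""
    return min(max(z, 0.0), n)


def exu_unit(x, w, b, n=1.0):
    """Assumed ExU form: the weight is exponentiated before scaling (x - b)."""
    return relu_n(math.exp(w) * (x - b), n)


def standard_unit(x, w, b):
    """Plain ReLU unit with the same (w, b) parameterisation, for comparison."""
    return max(0.0, w * (x - b))


w, b = 4.0, 0.5  # illustrative values only

# A shift of 0.02 in the input saturates the ExU unit completely...
exu_lo = exu_unit(0.50, w, b)
exu_hi = exu_unit(0.52, w, b)

# ...while the standard unit with the same parameters barely moves.
std_lo = standard_unit(0.50, w, b)
std_hi = standard_unit(0.52, w, b)
```

Here the ExU unit's effective slope is $e^{4} \approx 54.6$, so it jumps from 0 to its cap of 1 across an input change of 0.02, whereas the standard unit rises by only about 0.08. Sharp transitions of this kind are what let a sum of many such units trace out jagged shape functions.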