Mixture of Linear Models Co-supervised by Deep Neural Networks

Beomseok Seo, Lin Lin, Jia Li

TL;DR

A new method to estimate a mixture of linear models (MLM) for regression or classification that is relatively easy to interpret and allows us to trade off interpretability and accuracy.

Abstract

Deep neural network (DNN) models have achieved phenomenal success in applications across many domains, ranging from academic research in science and engineering to industry and business. The modeling power of DNNs is believed to come from the complexity and over-parameterization of the models, which, on the other hand, have been criticized for their lack of interpretability. Although this concern does not apply to every application, in some, especially in economics, social science, healthcare, and administrative decision making, scientists and practitioners are reluctant to use predictions made by a black-box system, for multiple reasons. One reason is that a major purpose of a study can be to make discoveries based on the prediction function, e.g., to reveal relationships between measurements. Another is that the training dataset may not be large enough for researchers to feel fully confident in a purely data-driven result. Being able to examine and interpret the prediction function enables researchers to connect the result with existing knowledge or to gain insight into new directions to explore. Although classic statistical models are much more explainable, their accuracy often falls considerably below that of DNNs. In this paper, we propose an approach to fill the gap between relatively simple explainable models and DNNs so that we can tune the trade-off between interpretability and accuracy more flexibly. Our main idea is a mixture of discriminative models trained with guidance from a DNN. Although mixtures of discriminative models have been studied before, our way of generating the mixture is quite different.
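The abstract describes the approach only at a high level, so the following Python sketch illustrates the general idea of a mixture of linear models fit to a teacher model's predictions. It is not the authors' algorithm: the k-means partition of the input space is a swapped-in choice (this section does not describe how the paper forms its cells or EPICs), each cell gets an ordinary least-squares model, and a fixed nonlinear function `teacher` stands in for a trained DNN. All names (`kmeans`, `assign`, `HardMixtureOfLinearModels`, `teacher`) are hypothetical.

```python
import numpy as np


def kmeans(X, k, n_iter=50, seed=0):
    """Plain Lloyd's k-means (assumed partitioning step). Returns (centers, labels)."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        labels = assign(X, centers)
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # keep the old center if a cell empties
                centers[j] = members.mean(axis=0)
    return centers, assign(X, centers)


def assign(X, centers):
    """Index of the nearest center for every row of X."""
    d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d.argmin(axis=1)


def with_intercept(X):
    """Prepend a column of ones so each local model has an intercept."""
    return np.hstack([np.ones((len(X), 1)), X])


class HardMixtureOfLinearModels:
    """Hypothetical sketch: one least-squares linear model per k-means cell."""

    def __init__(self, n_cells, seed=0):
        self.n_cells = n_cells
        self.seed = seed

    def fit(self, X, y):
        self.centers, labels = kmeans(X, self.n_cells, seed=self.seed)
        self.betas = np.zeros((self.n_cells, X.shape[1] + 1))
        for j in range(self.n_cells):
            mask = labels == j
            if not mask.any():  # an empty cell keeps all-zero coefficients
                continue
            self.betas[j] = np.linalg.lstsq(
                with_intercept(X[mask]), y[mask], rcond=None)[0]
        return self

    def predict(self, X):
        cells = assign(X, self.centers)
        # Row-wise dot product of each point with its own cell's coefficients.
        return (with_intercept(X) * self.betas[cells]).sum(axis=1)


# Demo. `teacher` is a fixed nonlinear function standing in for a trained DNN
# (assumption); the local models are fit to its outputs, not to raw labels.
rng = np.random.default_rng(1)
X = rng.uniform(-3, 3, size=(600, 2))


def teacher(X):
    return np.sin(X[:, 0]) + 0.5 * X[:, 1] ** 2


y_teacher = teacher(X)
single = HardMixtureOfLinearModels(n_cells=1).fit(X, y_teacher)
mixture = HardMixtureOfLinearModels(n_cells=8).fit(X, y_teacher)


def rmse(model):
    return float(np.sqrt(np.mean((model.predict(X) - y_teacher) ** 2)))


print(f"1 linear model:        RMSE {rmse(single):.3f}")
print(f"8 local linear models: RMSE {rmse(mixture):.3f}")
print(mixture.betas.round(2))  # one row (intercept, coef_x1, coef_x2) per cell
```

On the training points, the 8-cell mixture can never have a larger error than the single global linear model, because each cell's least-squares fit is at least as good as the global coefficients restricted to that cell. Each row of `mixture.betas` stays directly readable as a local linear effect; raising `n_cells` moves the sketch toward accuracy and away from interpretability.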

Paper Structure

This paper contains 20 sections, 19 equations, 8 figures, 5 tables.

Figures (8)

  • Figure 1: A schematic plot showing the steps of creating MLM.
  • Figure 2: (a-c) Scatter plots of the two explanatory variables, $X_1$ and $X_2$, showing the ground-truth groups or the convex hulls of the estimated clusters for each model. (d-k) Scatter plots of the predicted values, $\hat{Y}$, against an explanatory variable, $X_1$, for each model.
  • Figure 3: Estimated regression coefficients (blue dots) with "naive" confidence intervals (purple lines) for four selected variables in SKCM. The x-axis indicates the value of each regression coefficient. The linear effects of age, sex, whether the patient received adjuvant radiation treatment, and whether the patient has a history of other malignancies vary across EPICs.
  • Figure 4: Impact of hyper-parameters, $K_l$ and $\widetilde{J}$, on MLM for bike sharing data. (a) RMSE of MLM-cell at different $K_l$'s (the number of cells at layer $l$). (b) RMSE of MLM-EPIC at different $\widetilde{J}$ (the number of EPICs). (c) Histogram of the sizes of the $150$ EPICs. (d) Kernel density estimate (KDE) plot of the maximum weights assigned to an MLM-EPIC or MOE local expert model, computed at all the training points.
  • Figure 5: Regression coefficients of the mixture of linear models, $\{\mathbf{\hat{\beta}}_j|\mathbf{x}\in \mathcal{P}_j\}_{j=1}^{\widetilde{J}}$, for the $5$ largest EPICs in the bike sharing data.
  • ...and 3 more figures
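The captions above index one coefficient vector $\mathbf{\hat{\beta}}_j$ per EPIC $\mathcal{P}_j$ (Figure 5) and refer to the weights that MLM-EPIC and MOE assign to local models (Figure 4d). As a generic illustration only, and not necessarily the paper's exact formulation, the regression prediction of a mixture of linear models can be written as follows, assuming weights $w_j(\mathbf{x})$ (our notation) and an input $\mathbf{x}$ augmented with a leading 1 so that each $\mathbf{\hat{\beta}}_j$ includes an intercept:

```latex
\hat{y}(\mathbf{x}) \;=\; \sum_{j=1}^{\widetilde{J}} w_j(\mathbf{x})\, \mathbf{x}^{\top} \mathbf{\hat{\beta}}_j,
\qquad w_j(\mathbf{x}) \ge 0, \qquad \sum_{j=1}^{\widetilde{J}} w_j(\mathbf{x}) = 1 .
```

With a hard partition, $w_j(\mathbf{x}) = \mathbb{1}(\mathbf{x}\in\mathcal{P}_j)$, so each point is predicted by exactly one linear model and its maximum weight equals 1; softer weightings give maximum weights below 1, which is the quantity summarized in Figure 4(d).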