The Contextual Lasso: Sparse Linear Models via Deep Neural Networks

Ryan Thompson, Amir Dezfouli, Robert Kohn

TL;DR

The paper tackles the challenge of combining interpretability with expressivity by introducing the contextually sparse linear model, where the coefficient vector $\bm{\beta}(\mathbf{z})$ varies with contextual features $\mathbf{z}$. It learns $\bm{\beta}(\mathbf{z})$ via a neural network that outputs dense coefficients $\bm{\eta}(\mathbf{z})$ and uses a novel projection layer to enforce an average $\ell_1$ constraint across the batch, effectively sparsifying the context-dependent coefficients. Extensions include grouped explanatory features, side constraints, pathwise optimization, and a relaxed fit to reduce shrinkage bias, all implemented in the Julia package ContextualLasso. Empirical results on synthetic and real data show that the contextual lasso produces sparse, highly interpretable, context-aware predictors with competitive or superior predictive performance compared to baselines including deep nets and standard lasso-type methods.
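
As a rough sketch of the setup summarized above, the regression case can be written as follows. The symbols $\mathbf{x}$ (explanatory features), $y$ (response), $\varepsilon$ (noise), $n$ (number of observations), and $\lambda$ (constraint level) are notation assumed here for illustration, and any intercept term is ignored:

$$
y = \mathbf{x}^\top \bm{\beta}(\mathbf{z}) + \varepsilon,
\qquad
\frac{1}{n}\sum_{i=1}^{n} \lVert \bm{\beta}(\mathbf{z}_i) \rVert_1 \le \lambda .
$$

Because the $\ell_1$ budget is averaged over observations rather than imposed on each one, some contexts can use many explanatory features and others few, so long as the average stays within $\lambda$.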

Abstract

Sparse linear models are one of several core tools for interpretable machine learning, a field of emerging importance as predictive models permeate decision-making in many domains. Unfortunately, sparse linear models are far less flexible as functions of their input features than black-box models like deep neural networks. With this capability gap in mind, we study a not-uncommon situation where the input features dichotomize into two groups: explanatory features, which are candidates for inclusion as variables in an interpretable model, and contextual features, which select from the candidate variables and determine their effects. This dichotomy leads us to the contextual lasso, a new statistical estimator that fits a sparse linear model to the explanatory features such that the sparsity pattern and coefficients vary as a function of the contextual features. The fitting process learns this function nonparametrically via a deep neural network. To attain sparse coefficients, we train the network with a novel lasso regularizer in the form of a projection layer that maps the network's output onto the space of $\ell_1$-constrained linear models. An extensive suite of experiments on real and synthetic data suggests that the learned models, which remain highly transparent, can be sparser than the regular lasso without sacrificing the predictive power of a standard deep neural network.
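
To make the projection idea concrete, here is a minimal sketch in Python, not the authors' implementation (which the summary says is in Julia). It assumes the projection is a Euclidean projection of a batch of dense coefficient vectors onto the set whose average $\ell_1$ norm is at most `lam`, and it uses the standard sort-based $\ell_1$-ball projection to do so. The function name `project_batch_l1` and the example values are hypothetical.

```python
def project_batch_l1(etas, lam):
    """Project dense coefficients onto {betas : mean_i ||beta_i||_1 <= lam}.

    Assumption: Euclidean projection, computed by treating the whole batch
    as one long vector and projecting it onto an l1 ball of radius n * lam.
    """
    if lam < 0:
        raise ValueError("lam must be non-negative")
    n = len(etas)
    radius = n * lam
    magnitudes = [abs(v) for row in etas for v in row]

    # Already feasible: the projection leaves the coefficients unchanged.
    if sum(magnitudes) <= radius:
        return [list(row) for row in etas]
    # Zero budget: every coefficient is set to zero.
    if radius == 0:
        return [[0.0 for _ in row] for row in etas]

    # Find the soft-threshold level theta from the sorted magnitudes.
    theta = 0.0
    running_sum = 0.0
    for k, u in enumerate(sorted(magnitudes, reverse=True), start=1):
        running_sum += u
        candidate = (running_sum - radius) / k
        if u > candidate:
            theta = candidate
        else:
            break

    def soft_threshold(v):
        shrunk = abs(v) - theta
        if shrunk <= 0:
            return 0.0
        return shrunk if v > 0 else -shrunk

    return [[soft_threshold(v) for v in row] for row in etas]


# Two observations (contexts), three explanatory features each.
etas = [[3.0, -1.0, 0.2],
        [0.5, 2.0, -0.1]]
betas = project_batch_l1(etas, lam=2.0)

# Each row of betas can have a different sparsity pattern.
for row in betas:
    print([round(v, 3) for v in row])
```

In this sketch a single threshold is shared across the batch, so observations with small dense coefficients lose more features than those with large ones, which is one way the sparsity pattern can vary with context.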

Paper Structure (33 sections, 2 theorems, 16 equations, 13 figures, 4 tables, 3 algorithms)

Key Result

Proposition C.1

Let $\bm{\eta}_1,\ldots,\bm{\eta}_n\in\mathbb{R}^p$. Define $\tilde{\bm{\eta}}_1,\ldots,\tilde{\bm{\eta}}_n$ elementwise as ... Then optimization problem (eq:signprojection) admits the same optimal solution as ...

Figures (13)

  • Figure 1: Fitted coefficient functions from the contextual lasso for the house pricing dataset. Colored points indicate coefficient values at different locations. Grey points indicate coefficients equal to zero.
  • Figure 2: Network architecture. The contextual features $\mathbf{z}$ pass through a series of hidden layers. The resulting dense coefficients $\eta_1,\ldots,\eta_p$ then enter a projection layer to produce sparse coefficients $\beta_1,\ldots,\beta_p$. Here, the last coefficient is gray to illustrate that it is zeroed-out by the projection layer.
  • Figure 3: Comparisons on synthetic regression data. Metrics are aggregated over 10 synthetic datasets. Solid points are averages and error bars are standard errors. Dashed horizontal lines in the middle row indicate the true sparsity level.
  • Figure 4: Explanatory feature sparsity as a function of hour of day for the estimated energy consumption model. The sparsity level varies within each hour because the other contextual features vary.
  • Figure 5: Fitted spline function from the contextual lasso for the detrended fluctuation analysis (DFA) explanatory feature. The age explanatory feature is varied while the sex feature is set to female.
  • ...and 8 more figures

Theorems & Definitions (4)

  • Proposition C.1
  • Lemma C.2
  • proof
  • proof