Supervised learning pays attention

Erin Craig, Robert Tibshirani

TL;DR

The paper tackles heterogeneity in tabular data by introducing supervised attention to weight training examples for each test point, producing personalized, interpretable local models anchored to a global baseline. It formulates a practical two-step approach: estimate supervised similarity via random forest proximity, then fit a weighted local model and blend it with the global model; it specializes this to the attention lasso and demonstrates strong predictive gains on real and simulated data. Key contributions include the attention lasso with interpretable per-point coefficients clustered via protoclust, extensions to time series and spatial data, and a method for adapting pretrained tree models to distributional drift without refitting. The work also connects attention concepts to kernel methods and local regression, providing both theoretical and empirical support for reduced error under mixture-of-models and heterogeneous data settings.
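The first step, random forest proximity as a supervised similarity, can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the authors' implementation: it takes precomputed leaf indices (in scikit-learn, `RandomForestRegressor.apply` returns an array of shape `(n_samples, n_estimators)`), defines the proximity between a test point and a training row as the fraction of trees in which they share a leaf, and normalizes the proximities to sum to one. The helper name `rf_proximity_weights`, the normalization, and the uniform fallback are illustrative assumptions.

```python
import numpy as np


def rf_proximity_weights(train_leaves, test_leaves):
    """Hypothetical helper: attention weights for one test point from leaf co-membership.

    train_leaves: (n_train, n_trees) leaf index of each training row in each tree.
    test_leaves:  (n_trees,) leaf index of the test point in each tree.
    """
    train_leaves = np.asarray(train_leaves)
    test_leaves = np.asarray(test_leaves)
    # Proximity: fraction of trees in which the test point and a training row share a leaf.
    proximity = (train_leaves == test_leaves[np.newaxis, :]).mean(axis=1)
    total = proximity.sum()
    if total == 0:
        # Assumed fallback: no shared leaves at all, so use uniform weights.
        return np.full(train_leaves.shape[0], 1.0 / train_leaves.shape[0])
    # Assumed normalization: weights sum to one.
    return proximity / total


# Made-up leaf indices for 4 training rows and 3 trees.
train_leaves = np.array([[1, 4, 7], [1, 5, 7], [2, 5, 8], [3, 6, 9]])
test_leaves = np.array([1, 5, 7])
print(rf_proximity_weights(train_leaves, test_leaves))  # weights 1/3, 1/2, 1/6 and 0
```

Training rows that share many leaves with the test point receive the most weight, and a row that never shares a leaf with it receives none, so the similarity reflects the splits the forest found useful for predicting the outcome.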

Abstract

In-context learning with attention enables large neural networks to make context-specific predictions by selectively focusing on relevant examples. Here, we adapt this idea to supervised learning procedures for tabular data, such as lasso regression and gradient boosting. Our goals are to (1) flexibly fit personalized models for each prediction point and (2) retain model simplicity and interpretability. Our method fits a local model for each test observation by weighting the training data according to attention, a supervised similarity measure that emphasizes features and interactions that are predictive of the outcome. Attention weighting allows the method to adapt to heterogeneous data in a data-driven way, without requiring cluster or similarity pre-specification. Further, our approach is uniquely interpretable: for each test observation, we identify which features are most predictive and which training observations are most relevant. We then show how to use attention weighting for time series and spatial data, and we present a method for adapting pretrained tree-based models to distributional shift using attention-weighted residual corrections. Across real and simulated datasets, attention weighting improves predictive performance while preserving interpretability, and theory shows that attention-weighted linear models attain lower mean squared error than the standard linear model under mixture-of-models data-generating processes with known subgroup structure.
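The residual-correction idea for a pretrained model can be illustrated in its simplest form. In this sketch, which is an assumption and not necessarily the paper's exact procedure (the paper may fit a weighted model to the residuals rather than averaging them), the frozen model's prediction at a test point is shifted by the attention-weighted mean of its residuals on newly labeled, post-shift data, so the pretrained model itself is never refit. The function and argument names are hypothetical.

```python
import numpy as np


def residual_corrected_prediction(base_pred_star, base_pred_new, y_new, weights):
    """Hypothetical helper: shift a frozen model's prediction by weighted residuals.

    base_pred_star: pretrained model's prediction at the test point.
    base_pred_new:  pretrained model's predictions on newly labeled (shifted) data.
    y_new:          observed outcomes for that new data.
    weights:        attention weights of the new rows relative to the test point.
    """
    base_pred_new = np.asarray(base_pred_new, dtype=float)
    y_new = np.asarray(y_new, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalize the attention weights to sum to one
    residuals = y_new - base_pred_new  # errors of the frozen model after the shift
    # Correct the prediction with the weighted mean residual; the model is not refit.
    return float(base_pred_star + w @ residuals)


# Made-up numbers: residuals are 1, 0 and 2, and their weighted mean is 1.0.
print(residual_corrected_prediction(10.0, [1.0, 2.0, 3.0], [2.0, 2.0, 5.0], [0.5, 0.25, 0.25]))  # 11.0
```

Because only residuals are reweighted, the correction is cheap to compute per test point and leaves the original tree ensemble untouched.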

Paper Structure

This paper contains 29 sections, 6 theorems, 63 equations, 9 figures, 5 tables, and 4 algorithms.

Key Result

Lemma A.1

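Define the mean squared error (MSE) over the mixture distribution as

$$\mathrm{MSE}(\bm{\beta}) = \mathbb{E}_{(\bm{x}, Z)}\big[(y - \bm{x}^\top \bm{\beta})^2\big],$$

where $(\bm{x}, Z) \sim P(X, Z)$ with $y = \bm{x}^\top \bm{\beta}_Z + \varepsilon$. Then, provided $\mathbb{E}[\varepsilon \mid \bm{x}, Z] = 0$ and $\mathbb{E}[\bm{x}\bm{x}^\top]$ is invertible, the unpenalized minimizer is given by:

$$\bm{\beta}^\ast = \arg\min_{\bm{\beta}} \mathrm{MSE}(\bm{\beta}) = \mathbb{E}[\bm{x}\bm{x}^\top]^{-1}\,\mathbb{E}[\bm{x}\bm{x}^\top \bm{\beta}_Z].$$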
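The population minimizer in this mixture setting can be checked numerically. Under mean-zero noise independent of $(\bm{x}, Z)$, the standard least-squares population solution for a single global coefficient vector is $\mathbb{E}[\bm{x}\bm{x}^\top]^{-1}\mathbb{E}[\bm{x}\bm{x}^\top\bm{\beta}_Z]$. The sketch below assumes two equally likely subgroups with made-up coefficients, Gaussian features $\bm{x} \mid Z = k \sim N(\bm{\mu}_k, I)$ (so $\mathbb{E}[\bm{x}\bm{x}^\top \mid Z = k] = I + \bm{\mu}_k\bm{\mu}_k^\top$), and no intercept, and compares that formula with ordinary least squares on a large simulated sample.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
pi = 0.5  # P(Z = 1); made-up value
betas = [np.array([1.0, 0.0]), np.array([0.0, 2.0])]  # subgroup coefficients (made up)
means = [np.array([0.0, 0.0]), np.array([1.0, 1.0])]  # subgroup feature means (made up)
mix_probs = [1 - pi, pi]

# Population minimizer E[xx^T]^{-1} E[xx^T beta_Z], using E[xx^T | Z=k] = I + mu_k mu_k^T.
second_moments = [np.eye(2) + np.outer(m, m) for m in means]
A = sum(p * S for p, S in zip(mix_probs, second_moments))
b = sum(p * S @ beta for p, S, beta in zip(mix_probs, second_moments, betas))
beta_star = np.linalg.solve(A, b)

# Monte Carlo: simulate the mixture and fit one pooled least-squares model.
z = rng.random(n) < pi  # True means subgroup 1
x = rng.standard_normal((n, 2)) + np.where(z[:, None], means[1], means[0])
y = np.where(z, x @ betas[1], x @ betas[0]) + 0.5 * rng.standard_normal(n)
beta_ols, *_ = np.linalg.lstsq(x, y, rcond=None)

print(beta_star)  # approximately [0.625, 1.125]
print(beta_ols)   # close to beta_star
```

The pooled coefficients (about 0.625 and 1.125 here) match neither subgroup's coefficients, which illustrates the kind of bias that a single global linear model carries under a mixture of models.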

Figures (9)

  • Figure 1: Toy example of a dataset with two features, $x_1$ and $x_2$, where the true model depends on the values of the features.
  • Figure 2: An outline of supervised learning with attention. First, we fit a random forest to estimate similarity between observations; second, we fit a baseline model (for example, lasso or boosting) to the training data. Then, for each test observation $\textbf{x}^*$, we estimate attention weights from the random forest similarities between $\textbf{x}^*$ and each training observation, and we use these weights to fit an attention-weighted model specific to $\textbf{x}^*$. Finally, we blend the weighted model with the baseline model to make our prediction at $\textbf{x}^*$. More details are given in Algorithm \ref{alg:attention_super}.
  • Figure 3: Results described in Section \ref{sec:examples:real} and summarized in Table \ref{tab:results_22datasets}. Across 50 train/test splits for each dataset, attention lasso performs strongly relative to the lasso, XGBoost, LightGBM, random forest and KNN. In each plot, the vertical line at $x = 0\%$ indicates no change relative to the lasso, and larger values indicate better performance (lower PSE) than the lasso.
  • Figure 4: Clustered coefficients and performance for the Auto MPG (top row), Stock Portfolio Performance (middle row), and Facebook Metrics (bottom row) datasets. Models were trained on a random 50% of the data, and performance is reported on the remainder. Clustering the attention lasso coefficients reveals patterns that may be useful for characterizing data heterogeneity.
  • Figure 5: Simulation results described in Section \ref{sec:examples:sims} and summarized in Table \ref{tab:sim_results}. Across a variety of simulations, attention lasso typically (1) matches or improves on the lasso and (2) is competitive with more complex models when they perform well. The vertical line at $x = 0\%$ indicates no performance difference from the lasso, and values to the right indicate better performance.
  • ...and 4 more figures
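The per-test-point steps outlined in Figure 2 (fit an attention-weighted model, then blend it with the baseline) can be sketched as follows. To stay NumPy-only, this sketch swaps in weighted least squares for the weighted lasso, takes the attention weights as given (hand-set here rather than computed from a random forest), and blends with a fixed mixing weight `mix`. The helper names and the fixed `mix` are hypothetical; in practice the mixing weight would need to be tuned, for example by cross-validation.

```python
import numpy as np


def weighted_least_squares(X, y, w):
    """Weighted least squares with an intercept (a stand-in for the weighted lasso)."""
    X = np.asarray(X, dtype=float)
    Xd = np.column_stack([np.ones(len(X)), X])  # prepend an intercept column
    sw = np.sqrt(np.asarray(w, dtype=float))
    # Scaling rows by sqrt(w) turns weighted least squares into ordinary least squares.
    coef, *_ = np.linalg.lstsq(Xd * sw[:, None], np.asarray(y, dtype=float) * sw, rcond=None)
    return coef


def attention_prediction(X, y, x_star, weights, mix):
    """Hypothetical helper: blend a global fit and an attention-weighted local fit at x_star."""
    global_coef = weighted_least_squares(X, y, np.ones(len(X)))  # baseline model
    local_coef = weighted_least_squares(X, y, weights)  # model specific to x_star
    x_aug = np.concatenate([[1.0], np.asarray(x_star, dtype=float)])
    return (1 - mix) * (x_aug @ global_coef) + mix * (x_aug @ local_coef)


# Made-up data: two subgroups with different linear trends.
X = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [10.0], [11.0], [12.0], [13.0], [14.0]])
y = np.array([1.0, 3.0, 5.0, 7.0, 9.0, 0.0, -1.0, -2.0, -3.0, -4.0])
# Hand-set weights: the test point attends only to the first subgroup.
w = np.array([1.0] * 5 + [0.0] * 5)
x_star = np.array([1.0])
for mix in (0.0, 0.5, 1.0):
    print(mix, attention_prediction(X, y, x_star, w, mix))
```

With `mix = 0` the prediction is the global fit (about 5.17 here); with `mix = 1` it is the local fit, which recovers the first subgroup's trend $y = 1 + 2x$ and predicts 3 at $x = 1$. Intermediate values shrink the local model toward the global baseline.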

Theorems & Definitions (23)

  • Remark 2.1
  • Remark 2.2
  • Remark 2.3
  • Remark 2.4
  • Remark 7.1
  • Remark 8.1
  • Remark 8.2
  • Lemma A.1: Minimizer of mean squared error over the full population
  • proof
  • Theorem A.2: Irreducible Bias of Standard Lasso
  • ...and 13 more