Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning
Jannik Kossen, Neil Band, Clare Lyle, Aidan N. Gomez, Tom Rainforth, Yarin Gal
TL;DR
This work challenges the conventional parametric supervised-learning paradigm by introducing Non-Parametric Transformers (NPTs), which take the entire dataset as input and explicitly learn relationships between datapoints via self-attention. NPTs alternate attention between datapoints (ABD) with attention between attributes (ABA) within each datapoint. They are trained with a BERT-style masking objective to reconstruct corrupted inputs, so they learn end-to-end when to leverage information from other points. Empirically, NPTs achieve competitive results on tabular benchmarks and solve cross-datapoint lookup in semi-synthetic tasks. Corruption experiments and attention-map analyses show that they also rely on datapoint interactions on real data. The work provides evidence that modelling datapoint interactions can improve prediction. It also outlines scalability limitations and directions for future research, including efficient attention approximations and broader applicability to continual and multi-task learning.
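To make the difference between the two attention types concrete, here is a minimal NumPy sketch. It assumes the dataset is embedded as a tensor `h` of shape `(n, d, e)`: `n` datapoints, `d` attributes, and an `e`-dimensional embedding per attribute. ABD is modelled as flattening each row to a `d*e` vector so that rows attend to each other. ABA is modelled as attributes attending to each other within a single row. The helper names, the single-head attention, and the omission of residual connections, layer normalisation and feed-forward layers are simplifications for illustration only, not the authors' implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    # Single-head scaled dot-product self-attention over the second-to-last
    # axis ("tokens"). x: (..., tokens, dim); each weight matrix: (dim, dim).
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])
    return softmax(scores, axis=-1) @ v

def abd(h, w_q, w_k, w_v):
    # Attention between datapoints (hypothetical helper): each row's d
    # attribute embeddings are flattened into one d*e vector, and the n
    # rows attend to one another.
    n, d, e = h.shape
    rows = h.reshape(n, d * e)
    return self_attention(rows, w_q, w_k, w_v).reshape(n, d, e)

def aba(h, w_q, w_k, w_v):
    # Attention between attributes (hypothetical helper): within each row,
    # the d attributes attend to one another; rows never interact here.
    return self_attention(h, w_q, w_k, w_v)

rng = np.random.default_rng(0)
n, d, e = 5, 3, 4
h = rng.normal(size=(n, d, e))
w_abd = [rng.normal(size=(d * e, d * e)) * 0.1 for _ in range(3)]
w_aba = [rng.normal(size=(e, e)) * 0.1 for _ in range(3)]

# One illustrative "layer": ABD followed by ABA.
out = aba(abd(h, *w_abd), *w_aba)
print(out.shape)  # (5, 3, 4)

# Perturb only the last datapoint and check whether row 0's output changes.
h2 = h.copy()
h2[4] += 1.0
print(np.allclose(abd(h, *w_abd)[0], abd(h2, *w_abd)[0]))  # False: ABD mixes rows
print(np.allclose(aba(h, *w_aba)[0], aba(h2, *w_aba)[0]))  # True: ABA is per-row
```

The final check isolates the point of the architecture. Under ABD, the representation of one datapoint depends on the other datapoints. ABA by itself behaves like a conventional per-input model.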
Abstract
We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
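The end-to-end training signal described here is the BERT-style masking objective named in the TL;DR: some input entries are hidden, and the model must reconstruct them, possibly by looking at other datapoints. The sketch below shows one way such masking could look for a purely numeric table `x` of shape `(n, d)` with a label column `target_col`. Several choices here are assumptions, not the paper's settings. These include the masking probabilities `p_feature` and `p_target`, zero-filling masked values, appending the binary mask indicator as extra input columns, and a mean-squared-error loss. The function names are hypothetical.

```python
import numpy as np

def mask_dataset(x, target_col, is_test, p_feature=0.15, p_target=0.5, rng=None):
    # Hypothetical BERT-style masking for a numeric table x of shape (n, d).
    # is_test: boolean array of length n marking rows whose labels are unknown.
    rng = np.random.default_rng() if rng is None else rng
    n, d = x.shape
    # Randomly mask feature entries (illustrative probability).
    mask = rng.random((n, d)) < p_feature
    # Label column: training labels are masked stochastically,
    # test labels are always masked.
    mask[:, target_col] = np.where(is_test, True, rng.random(n) < p_target)
    # Zero-fill masked values and append the mask indicator as extra columns.
    x_in = np.where(mask, 0.0, x)
    model_input = np.concatenate([x_in, mask.astype(x.dtype)], axis=1)
    # The loss uses only masked entries whose true value is known,
    # so test labels never contribute.
    loss_mask = mask.copy()
    loss_mask[is_test, target_col] = False
    return model_input, loss_mask

def masked_mse(pred, x, loss_mask):
    # Reconstruction loss averaged over the selected entries only.
    if not loss_mask.any():
        return 0.0
    return float(((pred - x) ** 2)[loss_mask].mean())

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
is_test = np.array([False, False, False, False, True, True])
model_input, loss_mask = mask_dataset(x, target_col=3, is_test=is_test, rng=rng)
print(model_input.shape)  # (6, 8): 4 zero-filled values + 4 mask indicators
```

Because the whole table is passed in at once, a model trained this way can reconstruct a masked label from the features of the same row or from similar rows elsewhere in the dataset. Learning which source to use is left to the model.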
