
Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning

Jannik Kossen, Neil Band, Clare Lyle, Aidan N. Gomez, Tom Rainforth, Yarin Gal

TL;DR

This work challenges the conventional parametric supervision paradigm by introducing Non-Parametric Transformers (NPTs), which take the entire dataset as input and explicitly learn relationships between datapoints via self-attention. NPTs combine Attention Between Datapoints (ABD) with Attention Between Attributes (ABA), i.e. attention within each datapoint. They are trained with a BERT-style masking objective that reconstructs corrupted inputs, so the model learns end-to-end when to use information from other points. Empirically, NPTs achieve competitive results on tabular benchmarks and perform cross-datapoint lookup on semi-synthetic tasks. Data-corruption experiments and attention-map analyses show that they rely on interactions between datapoints on real data. The work provides evidence that modelling datapoint interactions can improve prediction. It also outlines scalability limitations and directions for future research, including efficient attention approximations and applications to continual and multi-task learning.
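To make the BERT-style masking objective concrete, here is a minimal NumPy sketch of how a masked input could be built for a numeric table. It pairs each value with a mask bit and scores reconstruction on masked entries only. The helper names (`make_masked_batch`, `masked_mse`), the separate feature and target masking rates, zero-filling of masked values, and the squared-error loss are our illustrative assumptions, not the paper's implementation. Categorical attributes are not handled here.

```python
import numpy as np


def make_masked_batch(X, target_col, p_feature=0.15, p_target=0.5, seed=None):
    """Hypothetical helper: corrupt a numeric (n, d) table BERT-style.

    Masked entries are replaced by 0 (our simplification) and each value is
    paired with its mask bit, giving an (n, d, 2) array to embed.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    M = rng.random((n, d)) < p_feature              # random feature masking
    M[:, target_col] = rng.random(n) < p_target     # separate rate for the target column
    M = M.astype(float)                             # 1.0 = masked, 0.0 = visible
    X_corrupt = np.where(M == 1.0, 0.0, X)          # hide masked values
    stacked = np.stack([X_corrupt, M], axis=-1)     # (value, mask bit) per entry
    return X_corrupt, M, stacked


def masked_mse(pred, X, M):
    """Reconstruction loss averaged over masked entries only."""
    sq_err = (pred - X) ** 2 * M
    return sq_err.sum() / max(M.sum(), 1.0)


# Usage: mask every target (as when predicting) and no features.
X = np.arange(12, dtype=float).reshape(4, 3)
X_corrupt, M, stacked = make_masked_batch(X, target_col=2, p_feature=0.0, p_target=1.0)
print(stacked.shape)  # (4, 3, 2)
```

Because the loss ignores unmasked entries, the model only gets credit for reconstructing values it could not see directly.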

Abstract

We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
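As a toy illustration of cross-datapoint lookup, the NumPy sketch below hand-sets, rather than learns, a single attention step between datapoints. It mirrors the Figure 3 setup, in which a revealed copy of the data is appended. Each row scores every row by negative squared feature distance, ignores rows whose target is hidden, and reads out the targets of the rows it attends to. The function `lookup_attention`, the distance-based score, and the temperature are illustrative assumptions. NPT learns its attention weights end-to-end instead.

```python
import numpy as np


def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    ez = np.exp(z)
    return ez / ez.sum(axis=axis, keepdims=True)


def lookup_attention(F, y, target_masked, temperature=1e-2):
    """Hand-set (not learned) attention between datapoints.

    F             : (n, k) features of every datapoint in the input.
    y             : (n,) targets, with hidden entries set to 0.
    target_masked : (n,) bool, True where the target is hidden.
    Returns the attended targets (n,) and the (n, n) attention weights.
    """
    sq_dist = ((F[:, None, :] - F[None, :, :]) ** 2).sum(axis=-1)  # (n, n)
    scores = -sq_dist / temperature                                # nearer rows score higher
    scores = np.where(target_masked[None, :], -np.inf, scores)     # never attend to hidden targets
    weights = softmax(scores, axis=1)
    return weights @ y, weights


# Figure-3-style input: original rows with hidden targets, then a revealed copy.
F = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y_true = np.array([3.0, -1.0, 2.0, 0.5])
F_all = np.concatenate([F, F])
masked = np.array([True] * 4 + [False] * 4)
y_in = np.where(masked, 0.0, np.concatenate([y_true, y_true]))
pred, W = lookup_attention(F_all, y_in, masked)
print(np.round(pred[:4], 3))  # recovers y_true for the masked rows
```

Excluding hidden-target rows as keys matters here. Without that exclusion, each masked row would match itself as well as its copy and would average in its own placeholder value.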

Paper Structure

This paper contains 79 sections, 5 theorems, 14 equations, 7 figures, 25 tables, 2 algorithms.

Key Result

Lemma 1

Any function of the form $f(X_1, \dots, X_n) = (g(X_1), \dots, g(X_n))$ for some $g$ is row-equivariant. We call such functions 'row-wise operations', since they apply the same function $g$ to each row of the input.
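A quick numerical check can build intuition for Lemma 1. Definition 1 is not reproduced here, so the sketch below assumes that row-equivariance means permuting the rows of the input permutes the rows of the output in the same way. The helpers `row_wise` and `is_row_equivariant` are hypothetical names. For contrast, the sketch also checks two other functions. The first is plain single-head self-attention across rows, which is row-equivariant because it has no positional encodings (a standard property). The second is a cumulative sum down the rows, which is not row-equivariant.

```python
import numpy as np


def row_wise(g):
    """Lift g: R^d -> R^h to a map on (n, d) arrays, as in Lemma 1."""
    return lambda X: np.stack([g(x) for x in X])


def is_row_equivariant(f, X, perm):
    """Check f(X[perm]) == f(X)[perm] for one permutation of the rows."""
    return np.allclose(f(X[perm]), f(X)[perm])


def self_attention(X, Wq, Wk, Wv):
    """Single-head self-attention with the n rows of X as tokens."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    S = Q @ K.T / np.sqrt(K.shape[1])
    S = S - S.max(axis=1, keepdims=True)  # numerical stability
    A = np.exp(S)
    A = A / A.sum(axis=1, keepdims=True)  # row-wise softmax
    return A @ V


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
perm = np.array([3, 0, 4, 1, 2])
W, b = rng.normal(size=(4, 6)), rng.normal(size=6)
Wq, Wk, Wv = (rng.normal(size=(4, 4)) for _ in range(3))

f = row_wise(lambda x: np.tanh(x @ W + b))  # a row-wise operation
print(is_row_equivariant(f, X, perm))                                        # True
print(is_row_equivariant(lambda Z: self_attention(Z, Wq, Wk, Wv), X, perm))  # True
print(is_row_equivariant(lambda Z: np.cumsum(Z, axis=0), X, perm))           # False
```

A single permutation is only a spot check, not a proof. The lemma's argument is that $g$ sees each row in isolation, so reordering rows can only reorder outputs.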

Figures (7)

  • Figure 1: NPTs learn direct interactions between datapoints. (a) Input data: predict the masked target entry [?] for datapoint $\bm{X}_i$. (b) Notation used in the section introducing NPTs. (c) Parametric models predict only from the features of the given input. (d) NPTs predict by modeling relationships between all points in the dataset.
  • Figure 2: Overview of the Non-Parametric Transformer. (a) The input dataset and mask matrix are stacked and (b) linearly embedded for all datapoints independently. NPT then applies (c) Attention Between Datapoints (ABD) across all $n$ samples of hidden dimension $h=d \cdot e$. (d) Attention Between Attributes (ABA) then attends between the attributes for each datapoint independently. We repeat steps (c) and (d) and obtain a final prediction from a separate linear projection (not shown).
  • Figure 3: Demonstrating NPT's ability to predict from Attention Between Datapoints (ABD). (a) To the original data with masked targets [?], we append a copy of the same data with all masked values revealed, so that perfect prediction via lookup is possible. (b) Attention weights indicate that NPT learns the ideal lookup behavior. Shown are the actual values learned by NPT at head 0 and depth 4 for the first 3 datapoints. (c) NPT predictions closely match the ideal values. (d) Additionally, we intervene on the values of individual targets, (e) finding that NPT predictions adjust accordingly.
  • Figure 4: Attention weights.
  • Figure B.1: Visualizations of NPT attention maps for Attention Between Datapoints (ABD) for the semi-synthetic experiment at all model depths, a selection of heads, and a single batch of input data. Evidently, not all attention maps need to perform a "lookup" for the model to solve the task. In fact, some heads appear to learn almost query-independent behavior (e.g., heads 0, 1, and 2 at depth 0).
  • ...and 2 more figures
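The steps in Figure 2 can be traced in a toy NumPy forward pass. This is a minimal sketch under simplifications of our own. It uses single-head attention and residual connections but no layer normalization or feed-forward layers. The weights are randomly initialized and untrained. Each attribute gets its own linear embedding of the (value, mask bit) pair. The names `attend`, `init_params`, and `npt_forward` are hypothetical. The sketch mainly shows the reshapes involved. ABD treats each datapoint as one token of size $h = d \cdot e$. ABA treats each attribute as a token of size $e$ within a single row.

```python
import numpy as np


def attend(Z, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over axis -2 of Z.

    Z is (tokens, dim) for ABD or (rows, tokens, dim) for ABA.
    """
    Q, K, V = Z @ Wq, Z @ Wk, Z @ Wv
    S = Q @ np.swapaxes(K, -1, -2) / np.sqrt(K.shape[-1])
    S = S - S.max(axis=-1, keepdims=True)  # numerical stability
    A = np.exp(S)
    A = A / A.sum(axis=-1, keepdims=True)
    return A @ V


def init_params(d, e, seed=0):
    """Random weights for a toy model; sizes follow Fig. 2 (h = d * e)."""
    rng = np.random.default_rng(seed)
    h = d * e

    def qkv(k):
        return [rng.normal(scale=k ** -0.5, size=(k, k)) for _ in range(3)]

    return {
        "embed": rng.normal(size=(d, 2, e)),               # one linear map per attribute
        "abd": qkv(h),                                     # Wq, Wk, Wv acting on h = d * e
        "aba": qkv(e),                                     # Wq, Wk, Wv acting on e
        "out": rng.normal(scale=e ** -0.5, size=(d, e)),   # per-attribute read-out
    }


def npt_forward(X, M, params, n_blocks=2):
    """Toy forward pass following steps (a)-(d) of Fig. 2."""
    n, d = X.shape
    e = params["embed"].shape[-1]
    stacked = np.stack([X, M], axis=-1)                        # (a) (n, d, 2)
    H = np.einsum("ndc,dce->nde", stacked, params["embed"])    # (b) (n, d, e)
    for _ in range(n_blocks):
        flat = H.reshape(n, d * e)                             # datapoints as tokens
        flat = flat + attend(flat, *params["abd"])             # (c) ABD with residual
        H = flat.reshape(n, d, e)
        H = H + attend(H, *params["aba"])                      # (d) ABA within each row
    return np.einsum("nde,de->nd", H, params["out"])           # final linear projection


n, d, e = 6, 3, 4
rng = np.random.default_rng(1)
M = np.zeros((n, d))
M[:, -1] = 1.0                                    # mask the last attribute (the target)
X = np.where(M == 1.0, 0.0, rng.normal(size=(n, d)))
params = init_params(d, e)
out = npt_forward(X, M, params)
print(out.shape)  # (6, 3)
```

ABD carries no positional information over datapoints, and every other step acts on rows independently. Permuting the input rows should therefore permute the output rows in the same way, which ties this sketch back to Lemma 1.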

Theorems & Definitions (12)

  • Definition 1
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • Lemma 4
  • proof
  • Lemma 5
  • ...and 2 more