Transformer Neural Processes - Kernel Regression
Daniel Jenson, Jhonathan Navott, Mengyan Zhang, Makkunda Sharma, Elizaveta Semenova, Seth Flaxman
TL;DR
Transformer Neural Process - Kernel Regression (TNP-KR) tackles the scalability-accuracy trade-off in Neural Processes by introducing a Kernel Regression Block and kernel-based attention biases. The two attention variants, Scan Attention (SA) and Deep Kernel Attention (DKA), enable efficient inference on very large context/test sets, with SA supporting translation invariance and DKA achieving near-linear complexity. Empirical results across 1D/2D GPs, Bayesian optimization, image completion, and epidemiology show consistent gains in predictive likelihood and competitive or superior performance compared to prior NP and Transformer-based methods. The approach offers a practical, extensible path to high-resolution, scalable probabilistic modeling in diverse scientific domains.
Abstract
Neural Processes (NPs) are a rapidly evolving class of models designed to directly model the posterior predictive distribution of stochastic processes. Originally developed as a scalable alternative to Gaussian Processes (GPs), which are limited by $O(n^3)$ runtime complexity, the most accurate modern NPs can often rival GPs but still suffer from an $O(n^2)$ bottleneck due to their attention mechanism. We introduce the Transformer Neural Process - Kernel Regression (TNP-KR), a scalable NP featuring: (1) a Kernel Regression Block (KRBlock), a simple, extensible, and parameter-efficient transformer block with complexity $O(n_c^2 + n_c n_t)$, where $n_c$ and $n_t$ are the number of context and test points, respectively; (2) a kernel-based attention bias; and (3) two novel attention mechanisms: scan attention (SA), a memory-efficient scan-based attention that, when paired with a kernel-based bias, can make TNP-KR translation invariant, and deep kernel attention (DKA), a Performer-style attention that implicitly incorporates a distance bias and further reduces complexity to $O(n_c)$. These enhancements enable both TNP-KR variants to perform inference with 100K context points on over 1M test points in under a minute on a single 24GB GPU. On benchmarks spanning meta regression, Bayesian optimization, image completion, and epidemiology, TNP-KR with DKA outperforms its Performer counterpart on nearly every benchmark, while TNP-KR with SA achieves state-of-the-art results.
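To make the link between attention and kernel regression concrete, the sketch below shows a toy 1D case in which the attention logits consist only of a distance-based bias. This is not the TNP-KR implementation: the function name `kernel_attention`, the squared-exponential form of the bias, and the `lengthscale` parameter are assumptions chosen for illustration. Under those assumptions, the softmax over context points reduces to Nadaraya-Watson kernel regression, the test-to-context pass costs $O(n_c n_t)$, and the output depends only on differences between inputs, so shifting all inputs by a constant leaves the predictions unchanged.

```python
import math


def kernel_attention(x_test, x_ctx, y_ctx, lengthscale=1.0):
    """Cross-attention from test points to context points using only a distance bias.

    Hypothetical illustration: the logit for (test i, context j) is
    -(x_i - x_j)^2 / (2 * lengthscale^2), so the softmax weights are
    normalized RBF kernel weights (Nadaraya-Watson kernel regression).
    """
    preds = []
    for xt in x_test:  # n_t queries
        # One logit per context point: n_c work per query, O(n_c * n_t) overall.
        logits = [-((xt - xc) ** 2) / (2.0 * lengthscale ** 2) for xc in x_ctx]
        m = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(l - m) for l in logits]
        total = sum(weights)
        # Weighted average of context values = attention output for this query.
        preds.append(sum(w * y for w, y in zip(weights, y_ctx)) / total)
    return preds


# Toy context set sampled from sin(x); the values are illustrative only.
x_ctx = [-2.0, -1.0, 0.0, 1.0, 2.0]
y_ctx = [math.sin(x) for x in x_ctx]
x_test = [-1.5, 0.0, 0.5]

preds = kernel_attention(x_test, x_ctx, y_ctx, lengthscale=0.5)

# Shift every input location by the same constant; the targets are unchanged.
shift = 10.0
shifted = kernel_attention(
    [x + shift for x in x_test],
    [x + shift for x in x_ctx],
    y_ctx,
    lengthscale=0.5,
)

# The two prediction lists agree up to floating-point error.
print(max(abs(a - b) for a, b in zip(preds, shifted)))
```

Because the bias is a function of `x_i - x_j` alone, the invariance holds by construction rather than being learned. A full model would also include learned query-key terms in the logits; how those interact with translation invariance is outside the scope of this sketch.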
