Decoupled-Value Attention for Prior-Data Fitted Networks: GP Inference for Physical Equations

Kaustubh Sharma, Simardeep Singh, Parikshit Pareek

TL;DR

Gaussian process (GP) inference is computationally expensive for large or evolving datasets. The authors propose Decoupled-Value Attention (DVA), which computes attention weights from the inputs only and passes labels solely through the values, thereby emulating GP updates without a fixed kernel. Empirical results show that DVA-based PFNs substantially reduce bias across 1D–10D tasks and scale to 64D power-flow tasks with over 80x speedups, approaching exact GP performance. The findings suggest that DVA enables scalable, uncertainty-aware, architecture-agnostic physics surrogates, which could support real-time, high-dimensional simulations of complex physical systems such as power grids.

Abstract

Prior-data fitted networks (PFNs) are a promising alternative to time-consuming Gaussian process (GP) inference for creating fast surrogates of physical systems. PFNs reduce the computational burden of GP training by replacing Bayesian inference in a GP with a single forward pass of a learned prediction model. However, with standard Transformer attention, PFNs show limited effectiveness on high-dimensional regression tasks. We introduce Decoupled-Value Attention (DVA), motivated by the GP property that the function space is fully characterized by the kernel over inputs and that the predictive mean is a weighted sum of training targets. DVA computes similarities from inputs only and propagates labels solely through values. Thus, DVA mirrors the GP update while remaining kernel-free. We demonstrate that PFNs are invariant to the backbone architecture and that the crucial factor for scaling PFNs is the attention rule rather than the architecture itself. Specifically, our results demonstrate that (a) localized attention consistently reduces out-of-sample validation loss in PFNs across different dimensional settings, with validation loss reduced by more than 50% in the five- and ten-dimensional cases, and (b) the attention rule is more decisive than the choice of backbone architecture, with CNN-, RNN- and LSTM-based PFNs performing on par with their Transformer-based counterparts. The proposed PFNs approximate 64-dimensional power flow equations with a mean absolute error on the order of $10^{-3}$, while being over 80x faster than exact GP inference.
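
The decoupling described in the abstract can be sketched in a few lines of NumPy. This is a minimal single-head illustration, not the authors' implementation: the weight names (`W_x`, `W_y`, `W_q`, `W_k`, `W_v`), the choice to build DVA values from label embeddings alone, and the 1/sqrt(d) scaling are assumptions. For contrast, `va_attention` adds label embeddings into the context tokens, roughly as in a standard PFN encoder, so its keys depend on the labels.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along `axis`.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def dva_attention(x_train, y_train, x_test, W_x, W_y, W_q, W_k, W_v):
    """Decoupled-value attention (sketch): similarities from inputs only, labels only in values."""
    h_train = x_train @ W_x.T                  # (n, d) input-only embeddings of context points
    h_test = x_test @ W_x.T                    # (m, d) input-only embeddings of query points
    q = h_test @ W_q.T                         # (m, d) queries
    k = h_train @ W_k.T                        # (n, d) keys: no label information
    v = (y_train[:, None] @ W_y.T) @ W_v.T     # (n, d) values built from labels only (assumption)
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (m, n), each row sums to 1
    return weights @ v, weights                # output = weighted sum of label-derived values


def va_attention(x_train, y_train, x_test, W_x, W_y, W_q, W_k, W_v):
    """Vanilla attention for contrast (sketch): labels are added into the context tokens."""
    h_train = x_train @ W_x.T + y_train[:, None] @ W_y.T  # (n, d) tokens mix inputs and labels
    h_test = x_test @ W_x.T
    q = h_test @ W_q.T
    k = h_train @ W_k.T                        # keys now depend on the labels
    v = h_train @ W_v.T
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return weights @ v, weights


# Usage with random (hypothetical) weights: D = 3 input dims, d = 8 hidden dims.
rng = np.random.default_rng(0)
D, d, n, m = 3, 8, 20, 5
W_x = rng.normal(scale=0.3, size=(d, D))
W_y = rng.normal(scale=0.3, size=(d, 1))
W_q, W_k, W_v = (rng.normal(scale=0.3, size=(d, d)) for _ in range(3))
x_train, x_test = rng.normal(size=(n, D)), rng.normal(size=(m, D))
y_a, y_b = rng.normal(size=n), rng.normal(size=n)  # two different label sets on the same inputs

out, w_dva_a = dva_attention(x_train, y_a, x_test, W_x, W_y, W_q, W_k, W_v)
_, w_dva_b = dva_attention(x_train, y_b, x_test, W_x, W_y, W_q, W_k, W_v)
_, w_va_a = va_attention(x_train, y_a, x_test, W_x, W_y, W_q, W_k, W_v)
_, w_va_b = va_attention(x_train, y_b, x_test, W_x, W_y, W_q, W_k, W_v)
print(np.allclose(w_dva_a, w_dva_b), np.allclose(w_va_a, w_va_b))  # True False
```

Because `q` and `k` never see the labels, the DVA weights act as a data-dependent similarity over inputs and the output is a weighted sum of label-derived values, mirroring the structure of a GP predictive mean. In the VA variant, changing the labels changes which context points are attended to.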

Paper Structure

This paper contains 32 sections, 2 theorems, 19 equations, 15 figures, 9 tables.

Key Result

Theorem 1

Assume the input encoder is linear, i.e., $\varphi_x(\mathbf{x}) = W_x \mathbf{x}$, and that the DVA query/key maps are $Q = W_q W_x \mathbf{x}$ and $K = W_k W_x \mathbf{x}$. Let $A := (W_q W_x)^\top (W_k W_x)$. If $A$ is symmetric positive definite and $\|\mathbf{z}\|_A^2 = \mathbf{z}^\top A\mathbf{z}$, then the …
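
The following derivation shows where a Mahalanobis RBF form can come from under the assumptions of Theorem 1. It is a sketch, not the paper's proof: it assumes unscaled dot-product logits (a $1/\sqrt{d_k}$ factor would simply rescale $A$), writes $\|\mathbf{z}\|_A^2 = \mathbf{z}^\top A\mathbf{z}$, and, because the theorem statement above is truncated, it does not show how the paper treats the key-dependent factor that remains.

```latex
\begin{aligned}
\mathbf{q}_i^\top \mathbf{k}_j
  &= \mathbf{x}_i^\top (W_q W_x)^\top (W_k W_x)\,\mathbf{x}_j
   = \mathbf{x}_i^\top A\,\mathbf{x}_j \\
  &= \tfrac12\left(\|\mathbf{x}_i\|_A^2 + \|\mathbf{x}_j\|_A^2
     - \|\mathbf{x}_i - \mathbf{x}_j\|_A^2\right)
     \qquad (A = A^\top), \\
\alpha_{ij}
  = \frac{\exp(\mathbf{q}_i^\top \mathbf{k}_j)}{\sum_{l}\exp(\mathbf{q}_i^\top \mathbf{k}_l)}
  &= \frac{\exp\!\left(-\tfrac12\|\mathbf{x}_i-\mathbf{x}_j\|_A^2\right)\exp\!\left(\tfrac12\|\mathbf{x}_j\|_A^2\right)}
          {\sum_{l}\exp\!\left(-\tfrac12\|\mathbf{x}_i-\mathbf{x}_l\|_A^2\right)\exp\!\left(\tfrac12\|\mathbf{x}_l\|_A^2\right)}.
\end{aligned}
```

The query term $\exp(\tfrac12\|\mathbf{x}_i\|_A^2)$ cancels in the softmax; when all keys have the same $A$-norm, the remaining key factor is constant and the weights are exactly a normalized Mahalanobis RBF kernel in the inputs.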

Figures (15)

  • Figure 1: Effect of Kernel in PFN Attention: Sample functions from 1D PFN training datasets (Left). Validation loss for smooth and non-smooth functions with Kernel-based Attention and DVA with Transformer (Middle) and CNN (Right).
  • Figure 2: Bias Reduction in PFN Training: Validation loss (NLL) versus number of training points for various PFNs (number of training points = epochs $\times$ steps per epoch $\times$ batch size $\times$ dataset size; dataset size is 100 for the 1D/2D, 400 for the 5D and 500 for the 10D PFN). Validation loss was computed on 64 out-of-sample datasets; the Transformer + VA result is taken from Müller et al. (2024).
  • Figure 3: Comparison of validation loss vs. number of training points ($N_{\text{train}}$) for RNN and LSTM architectures with VA and DVA attention.
  • Figure 4: In the first layer, the softmax weight decays exponentially as the Euclidean distance increases, showing that the DVA mechanism enforces localization. Layer 2 is the last layer of the model and produces its output; its softmax weights are all small and close to one another, suggesting that the last layer averages the result to approximate the posterior predictive distribution (PPD). In effect, Layer 1 performs "local smoothing" (gathering information from neighbors) and Layer 2 performs "feature mixing" (processing the gathered information). Since Layer 1 has already gathered the local information into the latent vector, Layer 2 no longer needs to be spatially local; it can attend globally or uniformly to refine the prediction. This provides empirical evidence for the DVA localization theorem.
  • Figure 5: In contrast to DVA, the standard VA Transformer assigns near-uniform attention weights across the entire input domain (flat trend lines). This confirms that without explicit localization through architectural changes (such as those made in DVA), the model defaults to global averaging rather than local interpolation. Attention weights in both layers remain effectively uniform regardless of input distance, empirically validating Theorem 6.3 of Nagler (2023).
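
The training-point count on the x-axis of Figure 2 is a plain product. The helper below (a hypothetical name, `n_train_points`) encodes the caption's formula and per-dimension dataset sizes; the epoch, step, and batch values in the usage line are illustrative, not the paper's settings.

```python
# Points per synthetic dataset, keyed by input dimension (from the Figure 2 caption).
DATASET_SIZE = {1: 100, 2: 100, 5: 400, 10: 500}


def n_train_points(input_dim: int, epochs: int, steps_per_epoch: int, batch_size: int) -> int:
    """N_train = epochs x steps per epoch x batch size x dataset size."""
    return epochs * steps_per_epoch * batch_size * DATASET_SIZE[input_dim]


# Illustrative settings only: 10 epochs, 100 steps per epoch, batch size 8, 5D PFN.
print(n_train_points(5, epochs=10, steps_per_epoch=100, batch_size=8))  # 3200000
```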

Theorems & Definitions (2)

  • Theorem 1: DVA attention weight $\propto$ Mahalanobis RBF kernel under linear embeddings
  • Theorem 2: DVA localization under nonlinear embeddings
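
The localization in the linear setting of Theorem 1 can be checked numerically. This sketch is an illustration under extra assumptions, not the paper's proof: it ties the query and key maps ($W_k = W_q$, a hypothetical choice that makes $A$ symmetric, and positive definite when $W_q W_x$ has full column rank), uses unscaled dot-product logits, and places every key on the unit $A$-sphere so that the key-norm factor is constant. Under these assumptions the softmax weights should coincide with a normalized Mahalanobis RBF kernel and decrease monotonically with Mahalanobis distance.

```python
import numpy as np

rng = np.random.default_rng(1)
D, d, n = 4, 16, 50                      # input dim, embedding dim, context points (hypothetical)
W_x = rng.normal(size=(d, D))            # linear input encoder
W_q = rng.normal(size=(d, d))
W_k = W_q                                # tied query/key maps (assumption) -> A is symmetric
A = (W_q @ W_x).T @ (W_k @ W_x)          # A = (W_q W_x)^T (W_k W_x), shape (D, D)


def a_norm_sq(z: np.ndarray) -> np.ndarray:
    # Squared A-norm z^T A z, computed over the last axis (works for one vector or a batch).
    return np.einsum("...i,ij,...j->...", z, A, z)


# Place every key on the unit A-sphere so ||x_j||_A^2 = 1 for all j.
z = rng.normal(size=(n, D))
x_keys = z / np.sqrt(a_norm_sq(z))[:, None]
x_query = rng.normal(size=D)
x_query = 1.5 * x_query / np.sqrt(a_norm_sq(x_query))    # query with ||x_q||_A = 1.5

# DVA attention weights for one query under linear embeddings (unscaled logits).
logits = (W_q @ W_x @ x_query) @ (W_k @ W_x @ x_keys.T)  # q^T k_j for every key j
weights = np.exp(logits - logits.max())
weights /= weights.sum()

# Normalized Mahalanobis RBF kernel exp(-0.5 * ||x_q - x_j||_A^2), computed stably.
dist_sq = a_norm_sq(x_query - x_keys)
rbf = np.exp(-0.5 * (dist_sq - dist_sq.min()))
rbf /= rbf.sum()

order = np.argsort(dist_sq)              # keys from nearest to farthest
print(np.allclose(weights, rbf), bool(np.all(np.diff(weights[order]) <= 1e-12)))  # True True
```

Dropping the unit-sphere constraint leaves the extra factor $\exp(\tfrac12\|\mathbf{x}_j\|_A^2)$ in the weights, so the exact RBF match in this sketch depends on that assumption.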