Efficient Autoregressive Inference for Transformer Probabilistic Models

Conor Hassan, Nasrulloh Loka, Cen-You Li, Daolang Huang, Paul E. Chang, Yang Yang, Francesco Silvestrin, Samuel Kaski, Luigi Acerbi

TL;DR

The paper addresses the challenge of obtaining coherent joint distributions from transformer-based probabilistic models conditioned on a context set. It introduces a causal autoregressive buffer that decouples one-time context encoding from subsequent autoregressive updates, enabling efficient batched joint sampling and one-pass likelihood evaluation while preserving marginal predictive quality. The authors formalize a four-rule attention scheme, provide a unified training strategy, and demonstrate up to 20× speedups across synthetic, EEG, multisensory, and tabular tasks with minimal cost to accuracy. This approach combines the efficiency of autoregressive generation with the representational strength of set-based conditioning, making joint prediction practical for transformer probabilistic models in diverse domains.
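For reference, the coherent joint distribution over targets that an autoregressive model produces is the standard chain-rule factorization below. The notation is assumed for illustration and may differ from the paper's: the context is $\mathcal{C}$ as in Figure 1, and the targets $y_1, \dots, y_K$ are taken in a fixed order.

$$
\log p(y_1, \dots, y_K \mid \mathcal{C}) = \sum_{k=1}^{K} \log p\left(y_k \mid \mathcal{C}, y_1, \dots, y_{k-1}\right)
$$

Each factor conditions on the context plus the targets that precede it. Sampling must generate the factors one at a time. When all $K$ target values are already known, however, a causal mask lets every factor be evaluated in a single forward pass, which is the sense in which likelihood evaluation is one-pass.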

Abstract

Transformer-based models for amortized probabilistic inference, such as neural processes, prior-fitted networks, and tabular foundation models, excel at single-pass marginal prediction. However, many real-world applications, from signal interpolation to multi-column tabular predictions, require coherent joint distributions that capture dependencies between predictions. While purely autoregressive architectures efficiently generate such distributions, they sacrifice the flexible set-conditioning that makes these models powerful for meta-learning. Conversely, the standard approach to obtaining joint distributions from set-based models requires expensive re-encoding of the entire augmented conditioning set at each autoregressive step. We introduce a causal autoregressive buffer that preserves the advantages of both paradigms. Our approach decouples context encoding from updating the conditioning set. The model processes the context once and caches it. A dynamic buffer then captures target dependencies: as targets are incorporated, they enter the buffer and attend to both the cached context and previously buffered targets. This enables efficient batched autoregressive generation and one-pass joint log-likelihood evaluation. A unified training strategy allows seamless integration of set-based and autoregressive modes at minimal additional cost. Across synthetic functions, EEG signals, cognitive models, and tabular data, our method matches the predictive accuracy of strong baselines while delivering up to 20 times faster joint sampling. Our approach combines the efficiency of autoregressive generative models with the representational power of set-based conditioning, making joint prediction practical for transformer-based probabilistic models.
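The abstract describes the attention pattern only informally, and this summary does not list the paper's four rules. The sketch below is a hypothetical NumPy reconstruction of the behavior the abstract describes: context tokens attend only to each other (so their encoding can be cached), buffer entries attend to the context and to earlier buffer entries, and the query for target k attends to the context and to buffer entries before k. The token layout, the choice that buffer entries also attend to themselves, the weight-free toy layer `attend`, and the helper name `buffer_attention_mask` are all assumptions for illustration, not the authors' implementation.

```python
import numpy as np


def buffer_attention_mask(n_ctx: int, n_tgt: int) -> np.ndarray:
    """Return M with M[i, j] True when token i may attend to token j.

    Assumed token layout (illustrative only):
      [0, n_ctx)                        context tokens
      [n_ctx, n_ctx + n_tgt)            buffer tokens (observed targets, in order)
      [n_ctx + n_tgt, n_ctx + 2*n_tgt)  query tokens, one per target
    """
    n = n_ctx + 2 * n_tgt
    ctx = slice(0, n_ctx)
    buf = slice(n_ctx, n_ctx + n_tgt)
    qry = slice(n_ctx + n_tgt, n)
    causal = np.tril(np.ones((n_tgt, n_tgt), dtype=bool))
    mask = np.zeros((n, n), dtype=bool)
    mask[ctx, ctx] = True                   # context sees only context, so it is cacheable
    mask[buf, ctx] = True                   # buffer entries see the whole context,
    mask[buf, buf] = causal                 # earlier buffer entries, and themselves
    mask[qry, ctx] = True                   # query k sees the whole context
    mask[qry, buf] = np.tril(causal, k=-1)  # and only buffer entries j < k
    return mask                             # queries never attend to other queries


def attend(x: np.ndarray, mask: np.ndarray, n_layers: int = 2) -> np.ndarray:
    """Toy masked self-attention with no learned weights, plus residuals."""
    h = x
    for _ in range(n_layers):
        scores = h @ h.T / np.sqrt(h.shape[1])
        scores = np.where(mask, scores, -np.inf)
        scores -= scores.max(axis=1, keepdims=True)  # each row has >= 1 allowed key
        weights = np.exp(scores)
        weights /= weights.sum(axis=1, keepdims=True)
        h = h + weights @ h
    return h


rng = np.random.default_rng(0)
N, K, d = 6, 4, 8
ctx = rng.normal(size=(N, d))  # embedded context points (hypothetical)
tgt = rng.normal(size=(K, d))  # embedded observed target values (hypothetical)
qry = rng.normal(size=(K, d))  # embedded target locations (hypothetical)

full = attend(np.concatenate([ctx, tgt, qry]), buffer_attention_mask(N, K))

# 1) Context rows equal a context-only encoding, so they can be computed once.
cached = attend(ctx, buffer_attention_mask(N, 0))
context_is_cacheable = np.allclose(full[:N], cached)

# 2) Query k matches a run containing only targets 0..k, so it never sees
#    y_k or later targets: all K conditionals come out of one masked pass.
no_peeking = all(
    np.allclose(
        full[N + K + k],
        attend(np.concatenate([ctx, tgt[: k + 1], qry[: k + 1]]),
               buffer_attention_mask(N, k + 1))[N + (k + 1) + k],
    )
    for k in range(K)
)
print(context_is_cacheable, no_peeking)
```

Because the context rows never depend on the buffer, an implementation along these lines could store their per-layer keys and values once and reuse them across all autoregressive steps and batched samples. The second check illustrates why a single masked pass can score every factor of a joint likelihood when the target values are known.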

Paper Structure

This paper contains 94 sections, 21 equations, 17 figures, 8 tables, 2 algorithms.

Figures (17)

  • Figure 1: The autoregressive buffer enables fast joint inference by eliminating redundant context re-computation. Left: Comparison of autoregressive inference strategies. The traditional autoregressive approach (top) requires re-encoding the entire augmented context set at each step when generating predictions for $K$ targets, leading to $\mathcal{O}(K(N+K)^2)$ complexity, where $N$ is the context set size and $K$ the number of targets. Our buffered approach (bottom) encodes the context $\mathcal{C}$ once and caches it. New predictions enter a causal autoregressive buffer that attends to both the static cache and previous buffer entries without re-encoding. Right: Empirical validation. We compare transformer probabilistic models with and without the buffer mechanism. Both strategies achieve comparable predictive accuracy, confirming that the buffer preserves model quality while delivering up to 20$\times$ faster sample generation at larger context sizes.
  • Figure 2: Example training mask.
  • Figure 3: Wall-clock time (log scale) for (left) sampling, (center) joint log-likelihood evaluation, and (right) a full training step, plotted as a function of the number of context points $N$. Our method demonstrates significant speedups over expressive autoregressive baselines.
  • Figure 4: Multisensory causal inference model comparison versus ground truth. (Left) Log marginal likelihood (LML) comparison for both $\rho=1$ and $\rho=4/3$. (Right) LML difference ($\rho=4/3 - \rho=1$) comparison. Our method closely aligns with the ground truth.
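The Figure 1 caption gives $\mathcal{O}(K(N+K)^2)$ for re-encoding. As a rough illustration of where the savings come from, the sketch below counts query-key score pairs in one self-attention layer for both strategies. The cost model is an assumption and not the paper's accounting: it counts only encoder attention pairs, ignores the per-step query token (identical in both schemes), and ignores heads, layers, and MLPs. The function names are hypothetical.

```python
def reencode_pairs(n_ctx: int, n_tgt: int) -> int:
    """Attention score pairs when the augmented set is re-encoded each step.

    At step k (0-based) the conditioning set holds n_ctx + k points, and
    full self-attention over it scores (n_ctx + k) ** 2 pairs.
    """
    return sum((n_ctx + k) ** 2 for k in range(n_tgt))


def buffered_pairs(n_ctx: int, n_tgt: int) -> int:
    """Attention score pairs when the context is encoded once and cached.

    The context costs n_ctx ** 2 once; the k-th buffer entry (0-based)
    attends to the n_ctx cached tokens, the k earlier entries, and itself.
    """
    return n_ctx ** 2 + sum(n_ctx + k + 1 for k in range(n_tgt))


n_tgt = 64
for n_ctx in (64, 256, 1024):
    ratio = reencode_pairs(n_ctx, n_tgt) / buffered_pairs(n_ctx, n_tgt)
    print(f"N={n_ctx:5d} K={n_tgt}: re-encoding / buffered = {ratio:.1f}")
```

Under this count, re-encoding grows as $\mathcal{O}(K(N+K)^2)$, matching Figure 1, while the buffered cost grows as $\mathcal{O}(N^2 + NK + K^2)$. Pair counts are only a proxy: measured wall-clock speedups such as those in Figure 3 also depend on hardware, batching, and the non-attention parts of the network.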