Bayesian Optimality of In-Context Learning with Selective State Spaces

Di Zhang, Jiaqi Xing

TL;DR

This work reframes in-context learning as Bayesian sequential inference over latent state-space tasks, formalizing ICL as meta-learning across Linear Gaussian State Space Models (LG-SSMs). It proves that meta-trained selective SSMs asymptotically implement the Bayes-optimal predictor, converging to the posterior predictive mean, and establishes a statistical separation under temporally structured noise from Transformers, which are viewed as performing gradient-descent-based ERM. The authors provide a Filter Representation Lemma showing that selective SSMs can realize Kalman-filter updates, Theorem 1 establishing asymptotic optimality (consistency and efficiency), and Theorem 2 showing strictly lower asymptotic risk than ERM on a family of tasks with correlated noise. Empirical results on synthetic LG-SSMs and a character-level Markov benchmark show faster convergence to the Bayes risk, better sample efficiency in structured-noise settings, and more faithful latent-state tracking than linear Transformers, offering a principled basis for architecture design in sequential inference tasks.

Abstract

We propose Bayesian optimal sequential prediction as a new principle for understanding in-context learning (ICL). Unlike interpretations that frame Transformers as performing implicit gradient descent, we formalize ICL as meta-learning over latent sequence tasks. For tasks governed by Linear Gaussian State Space Models (LG-SSMs), we prove that a meta-trained selective SSM asymptotically implements the Bayes-optimal predictor, converging to the posterior predictive mean. We further establish a statistical separation from gradient descent by constructing tasks with temporally correlated noise on which the optimal Bayesian predictor strictly outperforms any empirical risk minimization (ERM) estimator. Since Transformers can be viewed as performing implicit ERM, this implies that selective SSMs achieve lower asymptotic risk owing to superior statistical efficiency. Experiments on synthetic LG-SSM tasks and a character-level Markov benchmark confirm that selective SSMs converge faster to the Bayes-optimal risk, show superior sample efficiency with longer contexts in structured-noise settings, and track latent states more robustly than linear Transformers. This reframes ICL from "implicit optimization" to "optimal inference," explaining the efficiency of selective SSMs and offering a principled basis for architecture design.
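
The separation between Bayesian filtering and ERM under correlated noise can be illustrated numerically. The sketch below is a minimal illustration under stated assumptions, not the paper's construction: each hypothetical task has a latent level `mu` observed through AR(1)-correlated noise `e_t` (coefficient `phi`) plus small white noise, which is an LG-SSM with latent state $(\mu, e_t)$. The Bayes predictor is the Kalman filter's posterior predictive mean. As a deliberately simple stand-in for ERM, the context is pooled and the squared loss of a constant predictor is minimized, which yields the running sample mean and ignores temporal order. All names and parameter values are illustrative assumptions.

```python
import numpy as np

def kalman_predictive_means(A, C, Q, R, m0, P0, ys):
    """Return E[y_t | y_1..y_{t-1}] for t = 1..T, starting from the prior N(m0, P0) over z_0."""
    m, P = m0.copy(), P0.copy()
    preds = []
    for y in ys:
        m, P = A @ m, A @ P @ A.T + Q                 # predict: belief over z_t given y_1..y_{t-1}
        preds.append(C @ m)                           # posterior predictive mean of y_t
        K = P @ C.T @ np.linalg.inv(C @ P @ C.T + R)  # Kalman gain
        m = m + K @ (y - C @ m)                       # update: condition on y_t
        P = (np.eye(len(m)) - K @ C) @ P
    return np.array(preds)

# Hypothetical task family: y_t = mu + e_t + v_t with AR(1) noise e_t = phi * e_{t-1} + eps_t.
phi, sig_eps, sig_v, T, n_tasks = 0.9, 0.5, 0.1, 100, 200
A = np.diag([1.0, phi])            # latent state z_t = (mu, e_t); mu stays constant
C = np.array([[1.0, 1.0]])         # observation adds the level and the correlated noise
Q = np.diag([0.0, sig_eps**2])
R = np.array([[sig_v**2]])
m0 = np.zeros(2)
P0 = np.diag([1.0, sig_eps**2 / (1 - phi**2)])  # mu ~ N(0, 1); e_0 from the stationary law

rng = np.random.default_rng(0)
err_bayes, err_erm = [], []
for _ in range(n_tasks):
    mu = rng.normal()
    e = rng.normal(scale=np.sqrt(P0[1, 1]))
    ys = np.zeros((T, 1))
    for t in range(T):
        e = phi * e + rng.normal(scale=sig_eps)
        ys[t, 0] = mu + e + rng.normal(scale=sig_v)
    y = ys[:, 0]
    bayes = kalman_predictive_means(A, C, Q, R, m0, P0, ys)[1:, 0]  # predictions of y_2..y_T
    erm = np.cumsum(y)[:-1] / np.arange(1, T)  # constant-fit ERM: mean of y_1..y_{t-1} predicts y_t
    err_bayes.append(np.mean((y[1:] - bayes) ** 2))
    err_erm.append(np.mean((y[1:] - erm) ** 2))

print("Bayes (Kalman) MSE:", np.mean(err_bayes))
print("ERM (context mean) MSE:", np.mean(err_erm))
```

With strongly correlated noise (`phi = 0.9`), the filter exploits the fact that recent observations carry information about the current noise value, while the pooled-loss minimizer cannot. A richer ERM hypothesis class could narrow this particular gap, so the example illustrates the mechanism rather than reproducing Theorem 2.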

Paper Structure (27 sections, 3 theorems, 37 equations, 3 figures)

Key Result

Lemma 4.1

Consider a simplified selective SSM layer with state $h_t \in \mathbb{R}^n$ updated as

$$h_t = \overline{A}_t h_{t-1} + \overline{B}_t x_t,$$

where $x_t$ is the layer input at step $t$ and $\overline{A}_t, \overline{B}_t$ are generated by input-dependent selective networks. For any Linear Gaussian State Space Model (LG-SSM) with known parameters $\theta = (A, C, Q, R)$, there exists a parameterization of these selective networks such that, when processing a sequence $\mathcal{C}$ … where $K_t$ …
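
One way to see why a selective SSM can represent a Kalman filter is to rewrite the filter's mean update. With filtered mean $m_t = \mathbb{E}[z_t \mid y_1, \dots, y_t]$, the predict and update steps combine into $m_t = A m_{t-1} + K_t (y_t - C A m_{t-1}) = (I - K_t C) A\, m_{t-1} + K_t y_t$, which has the form of the selective recurrence $h_t = \overline{A}_t h_{t-1} + \overline{B}_t y_t$ with $\overline{A}_t = (I - K_t C) A$ and $\overline{B}_t = K_t$ when the layer input is the observation $y_t$. The sketch below checks this identity numerically. It assumes the LG-SSM $z_t = A z_{t-1} + w_t$, $y_t = C z_t + v_t$ with $w_t \sim \mathcal{N}(0, Q)$, $v_t \sim \mathcal{N}(0, R)$ and a known initial state $z_0 = 0$, and it hard-codes the gains that the selective networks would need to produce rather than learning them; it is not necessarily the construction used in the paper's proof, and the function names and parameter values are hypothetical.

```python
import numpy as np

def kalman_gains(A, C, Q, R, T):
    """Kalman gains K_1..K_T from the Riccati recursion (they do not depend on the data)."""
    n = A.shape[0]
    P = np.zeros((n, n))  # covariance of the belief over z_0 (z_0 = 0 is known)
    gains = []
    for _ in range(T):
        P_pred = A @ P @ A.T + Q
        K = P_pred @ C.T @ np.linalg.inv(C @ P_pred @ C.T + R)
        P = (np.eye(n) - K @ C) @ P_pred
        gains.append(K)
    return gains

def selective_scan(A, C, gains, ys):
    """Run h_t = Abar_t h_{t-1} + Bbar_t y_t with Abar_t = (I - K_t C) A and Bbar_t = K_t."""
    n = A.shape[0]
    h = np.zeros(n)
    hs = []
    for K, y in zip(gains, ys):
        A_bar = (np.eye(n) - K @ C) @ A  # per-step state-transition matrix
        B_bar = K                        # per-step input matrix
        h = A_bar @ h + B_bar @ y
        hs.append(h)
    return np.array(hs)

def kalman_filter_means(A, C, Q, R, ys):
    """Textbook predict/update Kalman filter, returning E[z_t | y_1..y_t]."""
    n = A.shape[0]
    m, P = np.zeros(n), np.zeros((n, n))
    means = []
    for y in ys:
        m, P = A @ m, A @ P @ A.T + Q
        K = P @ C.T @ np.linalg.inv(C @ P @ C.T + R)
        m = m + K @ (y - C @ m)
        P = (np.eye(n) - K @ C) @ P
        means.append(m)
    return np.array(means)

rng = np.random.default_rng(1)
A = np.array([[0.9, 0.2], [0.0, 0.8]])
C = np.array([[1.0, 0.5]])
Q, R = 0.2 * np.eye(2), np.array([[0.5]])
ys = rng.normal(size=(50, 1))  # any observation sequence works: the identity is algebraic
hs = selective_scan(A, C, kalman_gains(A, C, Q, R, len(ys)), ys)
ms = kalman_filter_means(A, C, Q, R, ys)
print("max |h_t - m_t|:", np.max(np.abs(hs - ms)))  # agreement up to floating-point error
next_obs_pred = hs @ (C @ A).T  # posterior predictive mean E[y_{t+1} | y_1..y_t]
```

For a fixed LG-SSM the Kalman gains depend only on the step index, not on the observed values, so this sketch does not exercise input-dependence; it only shows which per-step matrices a selective parameterization would have to output.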

Figures (3)

  • Figure 1: The Two Paradigms of ICL. (Left) GD/Transformer: the context is pooled into an empirical loss landscape (grey surface), and the forward pass performs gradient descent (red path) to find a minimizer $\hat{y}_{\text{GD}}$, ignoring the temporal structure. (Right) Selective SSM/Bayesian filtering: each observation updates an internal belief state (blue cloud) over latent variables, and the prediction is the mean of the evolved belief. This process naturally accounts for temporal correlations and uncertainty. Theorem 2 shows that the right-hand paradigm achieves strictly lower risk on tasks with structured temporal noise.
  • Figure 2: Convergence to Bayes Optimality. (Left) Excess risk versus number of meta-training tasks. The selective SSM's excess risk converges to zero (horizontal dashed line), i.e., it reaches the oracle risk. The linear Transformer plateaus at a higher excess risk, consistent with converging to an optimal ERM solution rather than to the Bayes-optimal predictor. (Right) Histogram of prediction errors on 1000 test sequences at the end of training. The SSM's error distribution is centered at zero and tightly concentrated around the oracle's residual error (shaded region), while the Transformer's errors have larger variance.
  • Figure 3: Performance on HMM-Generated Text. (Top) Next-character prediction accuracy versus context length $k$. Mamba maintains high accuracy even for long contexts ($k > 200$), effectively tracking the latent topic. The Transformer's accuracy peaks and then decays, as averaging over its fixed context window is confounded by topic shifts. (Bottom) A qualitative example. The context contains text from two distinct topics (indicated by color). Mamba correctly predicts a next character consistent with the current topic, whereas the Transformer's prediction reflects a blend of the two topics, leading to an error.
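
The topic tracking described for Figure 3 corresponds to Bayesian filtering over a discrete latent state, i.e., the HMM forward recursion. The sketch below assumes a hypothetical two-topic HMM over a four-character alphabet (the benchmark's actual vocabulary, topics, and transition probabilities are not given here) and compares the forward filter's predictive distribution with a baseline that averages character frequencies over the whole context. The baseline is a crude stand-in for context averaging, not a model of the Transformer in the figure. Log-loss is reported because it is less noisy than accuracy when two characters per topic are equally likely.

```python
import numpy as np

rng = np.random.default_rng(0)
trans = np.array([[0.98, 0.02],               # hypothetical sticky topic transitions
                  [0.02, 0.98]])
emit = np.array([[0.45, 0.45, 0.05, 0.05],    # topic 0 favors 'a', 'b'
                 [0.05, 0.05, 0.45, 0.45]])   # topic 1 favors 'c', 'd'
T = 2000

# Sample a topic path and the characters (encoded 0..3 for 'a'..'d') it emits.
topic = rng.choice(2)                         # topic_0 drawn uniformly
chars = np.zeros(T, dtype=int)
for t in range(T):
    topic = rng.choice(2, p=trans[topic])
    chars[t] = rng.choice(4, p=emit[topic])

belief = np.full(2, 0.5)   # P(topic_0): uniform, matching how the path was sampled
counts = np.ones(4)        # Laplace-smoothed character counts for the averaging baseline
loss_filter, loss_avg = [], []
for x in chars:
    prior = belief @ trans                    # P(topic_t | x_1..x_{t-1})
    p_next = prior @ emit                     # P(x_t | x_1..x_{t-1}), the filter's prediction
    loss_filter.append(-np.log(p_next[x]))
    loss_avg.append(-np.log(counts[x] / counts.sum()))
    post = prior * emit[:, x]                 # Bayes rule with the observed character
    belief = post / post.sum()
    counts[x] += 1

print("forward-filter log-loss:", np.mean(loss_filter))
print("context-average log-loss:", np.mean(loss_avg))
```

The filter re-weights its topic belief after every character, so a topic shift is absorbed within a few steps, whereas the frequency average drifts toward a mixture of both topics, loosely mirroring the "blend of the two topics" failure described in the caption.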

Theorems & Definitions (6)

  • Lemma 4.1: Filter Representation Lemma
  • Proof sketch (Lemma 4.1)
  • Theorem 1: Asymptotic Optimality
  • Proof intuition (Theorem 1)
  • Theorem 2: Risk Separation
  • Proof strategy (Theorem 2)