
Kalman Linear Attention: Parallel Bayesian Filtering For Efficient Language Modelling and State Tracking

Vaisakh Shaj, Cameron Barker, Aidan Scannell, Andras Szecsenyi, Elliot J. Crowley, Amos Storkey

TL;DR

Kalman Linear Attention (KLA) reframes sequence modelling as Bayesian filtering to enable long-context language modelling with explicit uncertainty. By reparameterising the Kalman filter in information form, KLA achieves associative Möbius-style updates that are parallelisable, matching the efficiency of sub-quadratic attention methods while adding nonlinear, uncertainty-driven gates. Empirically, KLA competes with or surpasses state-of-the-art baselines on synthetic LM benchmarks, long-context recall, and hard state-tracking tasks such as the A5 permutation problem, while offering interpretable posterior uncertainty dynamics. The work demonstrates that probabilistic filtering can provide stronger state-tracking capabilities without sacrificing scan-parallel scalability, suggesting a promising direction for scalable, uncertainty-aware language modelling. Overall, KLA broadens the toolbox for efficient sequence modelling by unifying probabilistic inference with modern parallel architectures and explicit belief-state representations.
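To make the associativity claim concrete, here is a minimal Python sketch for a single channel. It assumes a scalar model $x_t = a_t x_{t-1} + w_t$ with $w_t \sim \mathcal{N}(0, q_t)$, and an observation that contributes precision $\phi_t = k_t^2 \Lambda^{\mathrm v}_t$ (matching the definition in Theorem 1). Each precision step is then a Möbius (linear fractional) map, so it can be represented by a 2×2 matrix, composition becomes matrix multiplication, and prefixes can be computed with any associative scan. The function names are hypothetical, and the Hillis-Steele scan is a generic textbook scan, not the authors' Torch or Triton kernels.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
Mat = Tuple[float, float, float, float]  # 2x2 matrix in row-major order


def precision_step(lam: float, a: float, q: float, phi: float) -> float:
    """One sequential filter step for a single channel's precision."""
    lam_pred = lam / (a * a + q * lam)  # predict: 1 / (a^2 / lam + q)
    return lam_pred + phi               # update: add observation precision


def mobius(a: float, q: float, phi: float) -> Mat:
    """Matrix of lam -> ((1 + q*phi)*lam + a^2*phi) / (q*lam + a^2)."""
    return (1.0 + q * phi, a * a * phi, q, a * a)


def compose(first: Mat, second: Mat) -> Mat:
    """'Apply first, then second' is the matrix product second @ first."""
    a, b, c, d = second
    e, f, g, h = first
    return (a * e + b * g, a * f + b * h, c * e + d * g, c * f + d * h)


def apply(m: Mat, lam: float) -> float:
    a, b, c, d = m
    return (a * lam + b) / (c * lam + d)


def inclusive_scan(xs: Sequence[T], op: Callable[[T, T], T]) -> List[T]:
    """Hillis-Steele scan: ceil(log2 n) rounds of mutually independent updates."""
    out, step = list(xs), 1
    while step < len(out):
        # Build the new round from the previous round only, so every element
        # of a round could be computed in parallel.
        out = [out[i] if i < step else op(out[i - step], out[i])
               for i in range(len(out))]
        step *= 2
    return out


# Hypothetical per-step parameters (a_t, q_t, phi_t) and prior precision.
params = [(0.9, 0.1, 0.5), (0.8, 0.2, 0.0), (0.95, 0.05, 1.5),
          (0.7, 0.3, 0.2), (0.99, 0.01, 0.8)]
lam0 = 1.0

seq, lam = [], lam0
for a, q, phi in params:  # sequential recursion
    lam = precision_step(lam, a, q, phi)
    seq.append(lam)

prefix = inclusive_scan([mobius(*p) for p in params], compose)
par = [apply(m, lam0) for m in prefix]  # scan-based result
assert all(abs(s - p) < 1e-9 for s, p in zip(seq, par))
```

Scaling a matrix by any nonzero constant leaves its Möbius map unchanged, so a long-sequence implementation could renormalise each partial product to avoid overflow; this sketch does not.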

Abstract

State-space language models such as Mamba and gated linear attention (GLA) offer efficient alternatives to transformers due to their linear complexity and parallel training, but often lack the expressivity and robust state-tracking needed for complex reasoning. We address these limitations by reframing sequence modelling through a probabilistic lens, using Bayesian filters as a core primitive. While classical filters such as Kalman filters provide principled state estimation and uncertainty tracking, they are typically viewed as inherently sequential. We show that reparameterising the Kalman filter in information form enables its updates to be computed via an associative scan, allowing efficient parallel training. Building on this insight, we introduce the Kalman Linear Attention (KLA) layer, a neural sequence-modelling primitive that performs time-parallel probabilistic inference while maintaining explicit belief-state uncertainty. KLA offers strictly more expressive nonlinear updates and gating than GLA variants while retaining their computational advantages. On language modelling tasks, KLA matches or outperforms modern SSMs and GLAs across representative discrete token-manipulation and state-tracking benchmarks.

Paper Structure

This paper contains 84 sections, 9 theorems, 34 equations, 15 figures, 10 tables, 1 algorithm.

Key Result

Theorem 1

Let $\boldsymbol{\lambda}_t$ be the posterior precision at time $t$ in the diagonal linear-Gaussian model given by equations eq:ssm_dyn to eq:ssm_obs. Define $\boldsymbol{\phi}_t \coloneqq \mathbf{k}_t^{2}\odot\boldsymbol{\Lambda}_t^{\mathrm v}$. Then the map …
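The statement above is cut off, so the following is only a sketch of why such a precision map is Möbius, written per channel (scalars, since the model is diagonal). It assumes the dynamics $x_t = a_t x_{t-1} + w_t$ with $w_t \sim \mathcal{N}(0, q_t)$ and an observation $v_t = k_t x_t + \varepsilon_t$ with noise precision $\Lambda^{\mathrm v}_t$; the paper's equations eq:ssm_dyn and eq:ssm_obs may be parameterised differently. Here $M \cdot \lambda$ denotes $(\alpha\lambda + \beta)/(\gamma\lambda + \delta)$ for $M = \begin{pmatrix}\alpha & \beta \\ \gamma & \delta\end{pmatrix}$.

```latex
\begin{aligned}
\lambda_{t\mid t-1} &= \bigl(a_t^{2}\lambda_{t-1}^{-1} + q_t\bigr)^{-1}
  = \frac{\lambda_{t-1}}{a_t^{2} + q_t\lambda_{t-1}}
  && \text{(predict)} \\
\lambda_t &= \lambda_{t\mid t-1} + \phi_t, \qquad \phi_t = k_t^{2}\Lambda_t^{\mathrm v}
  && \text{(update)} \\
\lambda_t &= \frac{(1 + q_t\phi_t)\,\lambda_{t-1} + a_t^{2}\phi_t}{q_t\,\lambda_{t-1} + a_t^{2}}
  = M_t \cdot \lambda_{t-1},
  \qquad M_t = \begin{pmatrix} 1 + q_t\phi_t & a_t^{2}\phi_t \\ q_t & a_t^{2} \end{pmatrix} \\
\lambda_t &= \bigl(M_t M_{t-1} \cdots M_1\bigr) \cdot \lambda_0,
  \qquad \det M_t = a_t^{2} > 0 .
\end{aligned}
```

Since composing Möbius maps corresponds to multiplying their matrices, and matrix multiplication is associative, all prefixes $M_t \cdots M_1$ can be formed with a parallel prefix scan, which is the content of Corollary 1.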

Figures (15)

  • Figure 1: Minimum number of layers required to solve the $A_5$ (alternating group on 5 elements) permutation composition task [merrill2024illusion]. KLA's linear fractional (Möbius) updates sit between a fully nonlinear RNN and linear SSMs/transformers in expressivity, and KLA solves the task with only 1 or 2 layers without giving up parallelism.
  • Figure 2: From OU dynamics to parallel inference. (Top) Continuous-time OU prior. (Middle) Discrete linear-Gaussian SSM. (Bottom) Möbius scan for parallel posterior state estimation.
  • Figure 3: Block architecture. The block follows the fused-MLP design of Mamba, with the Kalman filter as a drop-in replacement for any SSM or attention primitive.
  • Figure 4: Ablation of OU prior dynamics and discretisation (see the stochastic dynamics section) on Selective Copy ($T=256$). OU discretisation improves accuracy and learning stability, especially for deeper models.
  • Figure 5: Training-time runtime scaling. Wall-clock runtime of KLA implementations across sequence lengths. Torch Scan uses torch._higher_order_ops.associative_scan; Triton Scan uses custom forward/backward kernels.
  • ...and 10 more figures
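Figures 2 and 4 refer to discretising a continuous-time Ornstein-Uhlenbeck (OU) prior into a discrete linear-Gaussian SSM. A minimal Python sketch of the standard exact discretisation of a scalar OU process $dx = -\theta x\,dt + \sigma\,dW$ is below; the names `theta`, `sigma`, and `dt` are hypothetical, and whether KLA parameterises its dynamics exactly this way is an assumption.

```python
import math
from typing import Tuple


def discretise_ou(theta: float, sigma: float, dt: float) -> Tuple[float, float]:
    """Exact discretisation of dx = -theta * x dt + sigma dW over a step dt.

    Returns (a, q) such that x_{t+dt} = a * x_t + w with w ~ N(0, q).
    """
    a = math.exp(-theta * dt)
    # 1 - exp(-2*theta*dt), computed with expm1 for accuracy at small dt
    q = sigma ** 2 / (2.0 * theta) * -math.expm1(-2.0 * theta * dt)
    return a, q


# Hypothetical values: mean-reversion rate, diffusion scale, step size.
theta, sigma, dt = 2.0, 0.5, 0.1
a, q = discretise_ou(theta, sigma, dt)

# One step preserves the stationary variance sigma^2 / (2 * theta).
stationary = sigma ** 2 / (2.0 * theta)
assert abs(a * a * stationary + q - stationary) < 1e-12
```

For $\theta > 0$ and $\Delta t > 0$ this always gives $0 < a < 1$ and $q > 0$, so the transition is stable and the predicted precision $\lambda/(a^2 + q\lambda)$ stays positive and bounded for any positive $\lambda$.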

Theorems & Definitions (16)

  • Theorem 1: Precision Update as a Möbius Transformation
  • Corollary 1: Precision Updates via Parallel Prefix Scan
  • Theorem 2: Mean Update as Affine Transformations
  • Corollary 2: Mean Updates via Parallel Prefix Scan
  • Theorem 1 (restated): Precision Update as a Möbius Transformation
  • Proof of Theorem 1
  • Corollary 1 (restated): Precision Updates via Parallel Prefix Scan
  • Proof of Corollary 1
  • Remark: Practical Implementation
  • Theorem 2 (restated): Mean Update as Affine Transformations
  • ...and 6 more
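Theorem 2 and Corollary 2 state that the mean update is affine and therefore scan-parallel. The minimal Python sketch below illustrates why under the same assumed scalar model ($x_t = a_t x_{t-1} + w_t$, $w_t \sim \mathcal{N}(0, q_t)$, $v_t = k_t x_t + \varepsilon_t$ with noise precision $\Lambda^{\mathrm v}_t$). It tracks the information vector $\eta_t = \lambda_t \mu_t$, which is our choice of parameterisation; the paper may state Theorem 2 for $\mu_t$ directly. Once the precisions are known, $\eta_t = c_t \eta_{t-1} + b_t$ with $c_t = a_t/(a_t^2 + q_t\lambda_{t-1})$ and $b_t = k_t \Lambda^{\mathrm v}_t v_t$, and affine maps compose associatively. The result is checked against a textbook covariance-form Kalman filter. All function names are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
Affine = Tuple[float, float]  # (c, b) represents eta -> c * eta + b


def compose_affine(first: Affine, second: Affine) -> Affine:
    """Apply 'first' then 'second': c2 * (c1 * eta + b1) + b2."""
    c1, b1 = first
    c2, b2 = second
    return (c2 * c1, c2 * b1 + b2)


def inclusive_scan(xs: Sequence[T], op: Callable[[T, T], T]) -> List[T]:
    """Log-depth (Hillis-Steele) inclusive scan with op(earlier, later)."""
    out, step = list(xs), 1
    while step < len(out):
        out = [out[i] if i < step else op(out[i - step], out[i])
               for i in range(len(out))]
        step *= 2
    return out


def kalman_reference(mu0, var0, steps):
    """Textbook covariance-form Kalman filter for one scalar channel."""
    mus, mu, var = [], mu0, var0
    for a, q, k, lam_v, v in steps:
        mu_pred, var_pred = a * mu, a * a * var + q
        gain = var_pred * k / (k * k * var_pred + 1.0 / lam_v)
        mu = mu_pred + gain * (v - k * mu_pred)
        var = (1.0 - gain * k) * var_pred
        mus.append(mu)
    return mus


def information_form_means(mu0, var0, steps):
    # 1) Precisions (sequential here; a Möbius-matrix scan would also work).
    lams, lam = [1.0 / var0], 1.0 / var0
    for a, q, k, lam_v, _ in steps:
        lam = lam / (a * a + q * lam) + k * k * lam_v
        lams.append(lam)
    # 2) Given the precisions, eta_t = c_t * eta_{t-1} + b_t is affine in eta.
    maps = [(a / (a * a + q * lams[t]), k * lam_v * v)
            for t, (a, q, k, lam_v, v) in enumerate(steps)]
    prefix = inclusive_scan(maps, compose_affine)
    eta0 = mu0 / var0
    # Convert back to means: mu_t = eta_t / lambda_t.
    return [(c * eta0 + b) / lams[t + 1] for t, (c, b) in enumerate(prefix)]


steps = [  # hypothetical (a, q, k, Lambda_v, v) per step
    (0.9, 0.1, 1.0, 2.0, 0.3),
    (0.8, 0.2, 0.5, 1.0, -0.1),
    (0.95, 0.05, 1.5, 4.0, 0.7),
    (0.7, 0.3, 0.0, 1.0, 0.0),  # k = 0: no information, pure prediction
    (0.99, 0.01, 1.2, 3.0, 0.4),
]
ref = kalman_reference(0.5, 1.0, steps)
par = information_form_means(0.5, 1.0, steps)
assert all(abs(r - p) < 1e-9 for r, p in zip(ref, par))
```

Because $c_t$ depends on the previous precision $\lambda_{t-1}$, the effective decay applied to $\eta$ is set by the posterior uncertainty rather than by the current input alone, which is one way to read the "uncertainty-driven gates" described in the TL;DR.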