Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression
Liangzu Peng, Aditya Chattopadhyay, Luca Zancato, Elvis Nunez, Wei Xia, Stefano Soatto
TL;DR
This work introduces Gated KalmaNet (GKA), a linear memory layer for large-scale sequence modeling, inspired by the Kalman Filter (KF), that accounts for the full past with constant memory and linear compute. Under a steady-state assumption, the KF recurrence reduces to a test-time ridge regression problem, which GKA keeps well conditioned through adaptive regularization and input-dependent gating and solves efficiently with Chebyshev Iteration and a hardware-aware chunk-wise implementation. GKA achieves state-of-the-art recall among linear fading-memory layers on synthetic MQAR tasks and more than 10% relative improvement on long-context benchmarks (RAG and LongQA) up to 128k tokens, with competitive throughput. The paper also connects GKA to existing SSM layers, demonstrates its potential for hybridization with Attention, and outlines future directions in non-linear extensions and faster implementations.
Abstract
As efficient alternatives to softmax Attention, linear State-Space Models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance in recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that accounts for the full past while maintaining SSM-style efficiency. We ground our approach in the Kalman Filter (KF) framework, which provides a principled solution for optimal inference in dynamical systems. We show that several existing SSM layers (DeltaNet, Gated DeltaNet, and Kimi Delta Attention) are approximations to the KF recurrence that assume identity error covariance, thereby ignoring how past measurements (keys and values) should optimally influence state updates. In contrast, GKA computes the exact Kalman gain by maintaining the full error covariance. Under a steady-state assumption that enables parallelization, this reduces to solving an online ridge regression problem with constant memory and linear compute cost. A critical insight is that standard KF equations are numerically unstable in low-precision environments (like bfloat16) and hard to parallelize on modern hardware. We address this through: (1) adaptive regularization with input-dependent gating to control the condition number of the ridge regression for numerical stability, and (2) Chebyshev Iteration, which we show is more stable than conventional iterative solvers in low-precision settings. We further develop hardware-aware chunk-wise kernels to enable efficient training. Empirically, GKA outperforms existing SSM layers (like Mamba2 and Gated DeltaNet) on short-context tasks and achieves more than 10% relative improvement on long-context RAG and LongQA tasks up to 128k tokens.
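To make the "online ridge regression with constant memory" idea concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes one plausible formulation: a gated key covariance H_t = g_t H_{t-1} + k_t k_t^T and a gated value-key moment U_t = g_t U_{t-1} + v_t k_t^T, with the output read out as U_t (H_t + lam I)^{-1} q_t. The function names, the scalar gate, the fixed regularizer `lam`, and the trace-based eigenvalue bound are all illustrative assumptions. The paper's adaptive regularization, low-precision handling, and chunk-wise kernels are not reproduced here. The linear solve uses the textbook Chebyshev iteration for symmetric positive-definite systems with known eigenvalue bounds.

```python
import numpy as np


def chebyshev_solve(A, b, lam_min, lam_max, num_iters=40):
    """Approximately solve A x = b with Chebyshev iteration.

    Assumes A is symmetric positive definite with all eigenvalues
    inside [lam_min, lam_max]. Starts from x = 0.
    """
    d = (lam_max + lam_min) / 2.0  # center of the eigenvalue interval
    c = (lam_max - lam_min) / 2.0  # half-width of the interval
    x = np.zeros_like(b)
    r = b.copy()                   # residual b - A x at x = 0
    p = np.zeros_like(b)
    alpha = 0.0
    for i in range(num_iters):
        if i == 0:
            p = r.copy()
            alpha = 1.0 / d
        else:
            # The second step uses a different coefficient than later steps.
            beta = 0.5 * (c * alpha) ** 2 if i == 1 else (c * alpha / 2.0) ** 2
            alpha = 1.0 / (d - beta / alpha)
            p = r + beta * p
        x = x + alpha * p
        r = b - A @ x
    return x


def fading_ridge_memory(keys, values, queries, gates, lam=1.0, num_iters=40):
    """Sequential (unchunked) sketch of a gated ridge-regression memory.

    keys: (T, n), values: (T, m), queries: (T, n), gates: (T,) in (0, 1].
    The state is H (n x n) and U (m x n), independent of sequence length T.
    """
    T, n = keys.shape
    m = values.shape[1]
    H = np.zeros((n, n))  # gated sum of k k^T (hypothetical state name)
    U = np.zeros((m, n))  # gated sum of v k^T (hypothetical state name)
    outputs = np.zeros((T, m))
    for t in range(T):
        H = gates[t] * H + np.outer(keys[t], keys[t])
        U = gates[t] * U + np.outer(values[t], keys[t])
        # H is positive semidefinite, so the eigenvalues of H + lam*I
        # lie in [lam, lam + trace(H)]; this bound is loose but safe.
        x = chebyshev_solve(H + lam * np.eye(n), queries[t],
                            lam, lam + np.trace(H), num_iters)
        outputs[t] = U @ x
    return outputs
```

Calling `fading_ridge_memory(K, V, Q, g)` with unit-norm keys and gates below 1 keeps trace(H) bounded, so the condition number of the regularized system stays small and a few dozen Chebyshev steps suffice in this sketch. A larger `lam` or smaller gates tighten the eigenvalue interval further, which is one way to read the abstract's point about controlling the condition number.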
