Gated KalmaNet: A Fading Memory Layer Through Test-Time Ridge Regression

Liangzu Peng, Aditya Chattopadhyay, Luca Zancato, Elvis Nunez, Wei Xia, Stefano Soatto

TL;DR

This work introduces Gated KalmaNet (GKA), a Kalman Filter (KF)-inspired linear memory layer for large-scale sequence modeling that accounts for the full past with constant memory and linear compute. GKA reframes the KF recurrence as a test-time ridge regression problem and keeps it numerically stable with adaptive regularization and input-dependent gating. It solves this problem efficiently with Chebyshev Iteration and a hardware-aware chunked implementation. GKA achieves state-of-the-art recall among linear fading-memory layers on synthetic MQAR tasks. On long-context benchmarks (RAG and LongQA) up to 128k tokens, it gives more than 10% relative improvement with competitive throughput. The work also connects GKA to existing SSM layers and shows potential for hybridizing it with Attention. It outlines future directions such as non-linear extensions and faster implementations.
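The test-time ridge regression view can be made concrete with a small reference loop. This is a minimal, unoptimized sketch under stated assumptions, not the authors' kernel. It assumes the gated objective $\min_S \sum_{i\le t} \big(\prod_{j=i+1}^{t}\gamma_j\big)\|S k_i - v_i\|^2 + \lambda_t\|S\|_{\text{F}}^2$ with hypothetical scalar gates $\gamma_t \in (0,1]$. It borrows the adaptive regularizer $\lambda_t = a\|H_t\|_{\text{F}}$ from Lemma 1.

Memory stays constant in sequence length because only $H_t$ (d×d) and $U_t$ (d_v×d) are carried forward. The function name `gated_ridge_readout` and the default `a=0.1` are illustrative. The sketch uses an exact dense solve where GKA uses Chebyshev Iteration.

```python
import numpy as np


def gated_ridge_readout(q, k, v, gamma, a=0.1):
    """Reference (slow) causal readout o_t = S_t q_t, where S_t solves the gated ridge problem
        min_S  sum_{i<=t} w_{t,i} ||S k_i - v_i||^2 + lam_t ||S||_F^2,
        w_{t,i} = gamma_{i+1} * ... * gamma_t   (hypothetical scalar gates in (0, 1]).

    Only two fixed-size statistics are carried across time (constant memory):
        H_t = gamma_t H_{t-1} + k_t k_t^T   (d x d)
        U_t = gamma_t U_{t-1} + v_t k_t^T   (d_v x d)
    so S_t = U_t (H_t + lam_t I)^{-1}, with lam_t = a ||H_t||_F (the form used in Lemma 1).
    """
    T, d = k.shape
    dv = v.shape[1]
    H = np.zeros((d, d))
    U = np.zeros((dv, d))
    out = np.empty((T, dv))
    for t in range(T):
        H = gamma[t] * H + np.outer(k[t], k[t])        # gated key covariance
        U = gamma[t] * U + np.outer(v[t], k[t])        # gated value-key cross term
        lam = a * np.linalg.norm(H, "fro")             # adaptive regularization
        x = np.linalg.solve(H + lam * np.eye(d), q[t])  # exact solve (GKA uses Chebyshev)
        out[t] = U @ x                                 # equals S_t q_t
    return out
```

Each step costs O(d^3) for the dense solve, which is why a cheap iterative solver matters in practice. This loop is only meant as a correctness reference.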

Abstract

As efficient alternatives to softmax Attention, linear State-Space Models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance in recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that accounts for the full past while maintaining SSM-style efficiency. We ground our approach in the Kalman Filter (KF) framework, which provides a principled solution for optimal inference in dynamical systems. We show that several existing SSM layers (DeltaNet, Gated DeltaNet, and Kimi Delta Attention) are approximations to the KF recurrence that assume an identity error covariance, thereby ignoring how past measurements (keys and values) should optimally influence state updates. In contrast, GKA computes the exact Kalman gain by maintaining the full error covariance. Under a steady-state assumption that enables parallelization, this reduces to solving an online ridge regression problem with constant memory and linear compute cost. A critical insight is that the standard KF equations are numerically unstable in low-precision environments (such as bfloat16) and hard to parallelize on modern hardware. We address this through (1) adaptive regularization with input-dependent gating, which controls the condition number of the ridge regression for numerical stability, and (2) Chebyshev Iteration, which we show is more stable than conventional iterative solvers in low-precision settings. We further develop hardware-aware chunk-wise kernels to enable efficient training. Empirically, GKA outperforms existing SSM layers (such as Mamba2 and Gated DeltaNet) on short-context tasks and achieves more than 10% relative improvement on long-context RAG and LongQA tasks up to 128k tokens.
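The link between adaptive regularization and the condition number can be spelled out in a few lines. This is a short derivation under stated assumptions, not quoted from the paper. It assumes $H_t$ is symmetric positive semidefinite, as any gated sum of outer products $k_i k_i^\top$ with nonnegative gates is. It also takes $\lambda_t = a\|H_t\|_{\text{F}}$ from Lemma 1, with $a > 0$ and $H_t \neq 0$. To avoid a clash with $\lambda_t$, $\mu_{\min}$ and $\mu_{\max}$ denote eigenvalues:

```latex
\kappa\bigl(H_t + \lambda_t I\bigr)
  = \frac{\mu_{\max}(H_t) + \lambda_t}{\mu_{\min}(H_t) + \lambda_t}
  \le \frac{\|H_t\|_{\text{F}} + a\|H_t\|_{\text{F}}}{0 + a\|H_t\|_{\text{F}}}
  = 1 + \frac{1}{a},
\qquad \text{since } \mu_{\min}(H_t) \ge 0 \text{ and } \mu_{\max}(H_t) = \|H_t\|_2 \le \|H_t\|_{\text{F}}.
```

Under these assumptions, the condition number is bounded independently of sequence length and of the scale of the keys; for example, $a = 0.1$ gives $\kappa \le 11$. The spectrum of $H_t + \lambda_t I$ also lies in the known interval $[\lambda_t, (1+a)\|H_t\|_{\text{F}}]$. Chebyshev Iteration takes exactly this kind of eigenvalue interval as input.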

Paper Structure

This paper contains 42 sections, 4 theorems, 83 equations, 10 figures, 9 tables, 2 algorithms.

Key Result

Lemma 1

With $\lambda_t= a \| H_t \|_{\text{F}}$, $w_t=\frac{ 2a \cdot (x_t^*)^\top dq_t}{\| H_t \|_{\text{F}} }$, we have

Figures (10)

  • Figure 1: CH (Chebyshev Iteration) converges to smaller errors than CG (conjugate gradient) and is more numerically stable. (a) Residual norms of different methods during the forward pass (batch size 8, sequence length 2048, 8 heads, head dimension 128). (b)-(e) Relative gradient differences between the exact solver (torch.linalg.solve) and CG (b, c) or CH (d, e), with the backward pass computed via implicit differentiation (impl) or torch.autograd (auto); cf. the implicit-differentiation table. In (b, d) the gradients are with respect to $[q_t,k_t]$; in (c, e) they are with respect to the network weights.
  • Figure 2: Our GKA block. Blue components denote established practices from the literature, and solid circles denote $\ell_2$ normalization. Green components (CH and the $\alpha$-connection) are our proposals.
  • Figure 3: (a) MQAR results: each plot corresponds to a particular sequence length and number of key-value pairs the model must memorize. (b) Runtime of a single forward + backward pass (8 heads, head dimension 128, batch size 4, averaged over 20 runs).
  • Figure 4: Long-context performance up to 128k tokens. GKA achieves strong RAG and LongQA performance, outperforming all baselines by 10% relative improvement. There is no clear winner on Synthetic Recall, and all models struggle to perform better than random chance on ICL.
  • Figure 5: (a) Theoretical lower and upper bounds on the divisor $b_i$ that arises when reversing Chebyshev Iteration (cf. the $\xi$ backward equation); (b) empirical lower and upper bounds on the divisor that arises when reversing CG.
  • ...and 5 more figures
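Figure 1 compares Chebyshev Iteration (CH) with conjugate gradient (CG). For readers unfamiliar with CH, below is a minimal NumPy sketch of the textbook three-term Chebyshev iteration for a symmetric positive definite system, as presented in Saad's *Iterative Methods for Sparse Linear Systems*. It is applied to $(H_t + \lambda_t I)x = q_t$ with $\lambda_t = a\|H_t\|_{\text{F}}$ as in Lemma 1. Because $H_t$ is assumed PSD, the eigenvalues lie in $[a\|H_t\|_{\text{F}}, (1+a)\|H_t\|_{\text{F}}]$, so the interval CH needs is available up front. Unlike CG, CH computes no inner products.

The backward pass uses implicit differentiation, one of the two options named in Figure 1, with $\lambda_t$ held constant. All function names are hypothetical. This is not the paper's chunked kernel, and it does not reproduce the iteration-reversal scheme of Figure 5.

```python
import numpy as np


def chebyshev_solve(A, b, lmin, lmax, iters=40):
    """Solve A x = b for SPD A whose eigenvalues lie in [lmin, lmax] (textbook Chebyshev iteration)."""
    theta = (lmax + lmin) / 2.0   # center of the spectral interval
    delta = (lmax - lmin) / 2.0   # half-width of the spectral interval
    sigma = theta / delta
    x = np.zeros_like(b)
    r = b.copy()                  # residual b - A x for the initial guess x = 0
    rho = 1.0 / sigma
    d = r / theta
    for _ in range(iters):        # fixed iteration count: no inner products, no early exit
        x = x + d
        r = r - A @ d
        rho_next = 1.0 / (2.0 * sigma - rho)
        d = rho_next * rho * d + (2.0 * rho_next / delta) * r
        rho = rho_next
    return x


def spectral_bounds(H, a):
    """Interval containing the eigenvalues of H + a*||H||_F*I when H is symmetric PSD and nonzero."""
    fro = np.linalg.norm(H, "fro")
    return a * fro, (1.0 + a) * fro


def ridge_solve_with_grads(H, q, g, a=0.1, iters=40):
    """Forward x = (H + lam I)^{-1} q with lam = a ||H||_F; backward by implicit differentiation.

    g is the upstream gradient dL/dx. Since A = H + lam I is symmetric,
    dL/dq = A^{-1} g and dL/dA = -(dL/dq) x^T (lam treated as a constant here).
    """
    lmin, lmax = spectral_bounds(H, a)
    A = H + lmin * np.eye(H.shape[0])                # lmin equals lam = a ||H||_F
    x = chebyshev_solve(A, q, lmin, lmax, iters)
    dq = chebyshev_solve(A, g, lmin, lmax, iters)    # adjoint solve reuses the same solver
    dA = -np.outer(dq, x)
    return x, dq, dA
```

In this sketch the regularizer does double duty. It bounds the condition number, and it hands CH its spectral interval, so the solver reduces to a fixed sequence of matrix-vector products and vector updates. The trade-off is that CH converges only as fast as that interval allows, while CG adapts to the actual spectrum.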

Theorems & Definitions (8)

  • Lemma 1
  • Lemma 2
  • Lemma 3
  • proof
  • Lemma 4
  • proof
  • Remark 1
  • Remark 2