KalMamba: Towards Efficient Probabilistic State Space Models for RL under Uncertainty

Philipp Becker, Niklas Freymuth, Gerhard Neumann

TL;DR

KalMamba is an efficient architecture to learn representations for RL that combines the strengths of probabilistic SSMs with the scalability of deterministic SSMs, and competes with state-of-the-art SSM approaches in RL while significantly improving computational efficiency, especially on longer interaction sequences.

Abstract

Probabilistic State Space Models (SSMs) are essential for Reinforcement Learning (RL) from high-dimensional, partial information as they provide concise representations for control. Yet, they lack the computational efficiency of their recent deterministic counterparts such as S4 or Mamba. We propose KalMamba, an efficient architecture to learn representations for RL that combines the strengths of probabilistic SSMs with the scalability of deterministic SSMs. KalMamba leverages Mamba to learn the dynamics parameters of a linear Gaussian SSM in a latent space. Inference in this latent space amounts to standard Kalman filtering and smoothing. We realize these operations using parallel associative scanning, similar to Mamba, to obtain a principled, highly efficient, and scalable probabilistic SSM. Our experiments show that KalMamba competes with state-of-the-art SSM approaches in RL while significantly improving computational efficiency, especially on longer interaction sequences.
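To make the filtering step mentioned in the abstract concrete, here is a minimal sketch of a sequential Kalman filter for a scalar linear Gaussian SSM with time-varying dynamics. The identity observation model, the scalar parameters `a`, `b`, `q`, the function name `kalman_filter`, and the toy inputs are all illustrative assumptions; this is not the paper's implementation, which evaluates these operations with a parallel scan over full sequences.

```python
def kalman_filter(obs, obs_var, a, b, q, mean0=0.0, var0=1.0):
    """Scalar Kalman filter (illustrative sketch, assumed model):

        z_{t+1} = a[t] * z_t + b[t] + N(0, q[t])   # linear Gaussian dynamics
        w_t     = z_t + N(0, obs_var[t])           # identity observation model

    Returns the filtered means and variances of q(z_t | w_{<=t}).
    """
    means, variances = [], []
    mean, var = mean0, var0  # prior belief over z_0
    for t in range(len(obs)):
        # Update: condition the current belief on the observation w_t.
        gain = var / (var + obs_var[t])
        mean = mean + gain * (obs[t] - mean)
        var = (1.0 - gain) * var
        means.append(mean)
        variances.append(var)
        # Predict: push the filtered belief through the linear dynamics.
        mean = a[t] * mean + b[t]
        var = a[t] * a[t] * var + q[t]
    return means, variances


# Toy usage: static latent state observed twice with unit noise.
means, variances = kalman_filter(
    obs=[1.0, 1.0], obs_var=[1.0, 1.0],
    a=[1.0, 1.0], b=[0.0, 0.0], q=[0.0, 0.0],
)
```

In this toy run the filtered variance shrinks with each observation, which is the behavior one would expect from a belief that accumulates evidence over time.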

Paper Structure (15 sections, 4 equations, 9 figures, 3 tables)

Figures (9)

  • Figure 1: Overview of KalMamba. The observation-action sequences are first fed through a dynamics backbone built on Mamba [gu2023mamba] to learn a linear dynamics model for each step. KalMamba then uses time-parallel Kalman filtering [sarkka2020temporal] to infer filtered beliefs $q(\boldsymbol{\mathrm{z}}_t | \boldsymbol{\mathrm{o}}_{\leq t}, \boldsymbol{\mathrm{a}}_{\leq t-1})$, which can be used for control with a Soft Actor-Critic (SAC) [haarnoja2018sac]. For model training, KalMamba employs an additional time-parallel Kalman smoothing step to obtain smoothed beliefs $q(\boldsymbol{\mathrm{z}}_t | \boldsymbol{\mathrm{o}}_{\leq T}, \boldsymbol{\mathrm{a}}_{\leq T})$. These beliefs allow training a model that excels in modeling uncertainties due to a tight variational lower bound [becker2022uncertainty]. Crucially, the smoothing step does not introduce trainable model parameters, enabling the direct use of the filtered beliefs for downstream RL policy training and execution.
  • Figure 2: Schematic of the Mamba-based [gu2023mamba] backbone used to learn the system dynamics. It shares the inference model's encoder $\phi(\boldsymbol{\mathrm{o}}_t)$ and intermediate representation $\boldsymbol{\mathrm{w}}_t$. Each $\boldsymbol{\mathrm{w}}_t$ is concatenated with the previous action $\boldsymbol{\mathrm{a}}_{t-1}$, fed through a small Neural Network (NN), and given to a Mamba model, which accumulates information over time and emits a representation $\boldsymbol{\mathrm{m}}_t(\boldsymbol{\mathrm{o}}_{\leq t}, \boldsymbol{\mathrm{a}}_{\leq t-1})$ containing the same information as the filtered belief $q(\boldsymbol{\mathrm{z}}_t | \boldsymbol{\mathrm{o}}_{\leq t}, \boldsymbol{\mathrm{a}}_{\leq t-1})$. We then concatenate each $\boldsymbol{\mathrm{m}}_t$ with the current action $\boldsymbol{\mathrm{a}}_t$ and use another small NN to compute the dynamics parameters $\boldsymbol{\mathrm{A}}_t$, $\boldsymbol{\mathrm{b}}_t$, and $\boldsymbol{\mathrm{\Sigma}}_t$. This scheme allows us to regularize the intermediate representation $\boldsymbol{\mathrm{m}}_t$ towards the filtered belief's mean using a Mahalanobis regularizer (cf. Equation \ref{eq:mahal_reg}). Finally, the small NNs include Monte-Carlo Dropout [gal2016dropout] to model epistemic uncertainty.
  • Figure 3: Aggregated expected returns for image-based observations. (Left) KalMamba is slightly worse but overall competitive with the different baselines. Combining either baseline SSM with SAC matches or exceeds the performance of DreamerV3. (Right) Using Mamba to learn the dynamics is crucial for good model performance. Similarly, both Monte-Carlo Dropout and the regularization loss of Equation \ref{eq:mahal_reg} stabilize the training process and lead to higher expected returns.
  • Figure 4: Aggregated expected returns for the state-based noisy tasks. KalMamba clearly outperforms the RSSM while almost matching the VRKN's performance. Naively using SAC is insufficient, which testifies to the increased difficulty due to the noise.
  • Figure 5: Wall-clock time evaluations on the state-based noisy walker-walk task for KalMamba, the RSSM, and the VRKN with different training context lengths, run for $1$ million environment steps or up to $24$ hours. The time limit only affected the VRKN with a context length of $256$ steps, which reached $650$ thousand environment steps after $24$ hours. While all methods work well for short sequences of length $32$ (Top Left), the efficient parallelization of KalMamba allows it to scale gracefully to, and even improve performance for, longer sequences of up to $256$ steps, where the other methods fail (Bottom Right).
  • ...and 4 more figures
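
The time-parallel filtering and smoothing referred to in Figure 1 rest on associative scans. As a hedged illustration of that principle only, the sketch below scans over affine maps $z \mapsto a z + b$ of a scalar linear recurrence and checks the result against a plain sequential loop. The element type, the names `combine` and `associative_scan`, and the toy numbers are assumptions made for illustration; the scan elements used for actual Kalman filtering and smoothing are richer than these affine maps, and this pure-Python loop only mimics what would run in parallel on an accelerator.

```python
def combine(first, second):
    """Compose two affine maps: apply `first`, then `second`.

    (a1, b1) followed by (a2, b2) maps z to a2 * (a1 * z + b1) + b2.
    This composition is associative, which is what permits a scan.
    """
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def associative_scan(elems):
    """Inclusive scan by recursive doubling (Hillis-Steele style).

    Needs about log2(T) rounds; within a round every position t is
    independent of the others, so the round could run in parallel.
    """
    out = list(elems)
    step = 1
    while step < len(out):
        prev = out
        out = [
            prev[t] if t < step else combine(prev[t - step], prev[t])
            for t in range(len(prev))
        ]
        step *= 2
    return out


def sequential_states(elems, z0):
    """Reference: unroll z_t = a_t * z_{t-1} + b_t one step at a time."""
    states, z = [], z0
    for a, b in elems:
        z = a * z + b
        states.append(z)
    return states


# Toy usage with hypothetical per-step dynamics (a_t, b_t).
elems = [(0.5, 1.0), (2.0, -1.0), (1.5, 0.25), (0.75, 0.0), (1.0, 2.0)]
z0 = 3.0
scanned_states = [a * z0 + b for a, b in associative_scan(elems)]
reference_states = sequential_states(elems, z0)
```

Both routes yield the same states; the scan simply reorganizes the computation so that its depth grows logarithmically rather than linearly with the sequence length.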