Bayesian Meta-Reinforcement Learning with Laplace Variational Recurrent Networks

Joery A. de Vries, Jinke He, Mathijs M. de Weerdt, Matthijs T. J. Spaan

TL;DR

This paper reframes memory-based meta-reinforcement learning through a Bayesian lens and introduces Laplace Variational Recurrent Networks (Laplace VRNNs) to obtain posterior uncertainty without redesigning existing architectures. Applying the Laplace approximation to the latent-variable posterior yields a Gaussian distribution over environment representations that can be used for posterior predictive decision-making and uncertainty estimation, even for non-Bayesian baseline agents. The authors formulate a probabilistic graphical model aligned with memory-based meta-RL, derive practical lower bounds, and show that Laplace VRNNs can match variational baselines with far fewer learnable parameters. Empirical results on supervised and reinforcement learning tasks demonstrate useful posterior statistics and robust performance, though the approach relies on simplifying assumptions (e.g., Gaussian posteriors and Jacobian-based curvature) and incurs extra computational cost from computing Jacobians. Overall, the method provides a lightweight, architecture-preserving route to uncertainty quantification and more robust decision-making in meta-RL.
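Posterior predictive decision-making with a Gaussian task posterior can be illustrated with a minimal Monte Carlo sketch. It assumes a posterior $q(Z) = \mathcal{N}(\mu, \Sigma)$ over the latent task representation and a hypothetical linear softmax head (`W`, `b`) standing in for the agent's policy or predictive decoder; the head and all function names are illustrative assumptions, not the paper's architecture.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def posterior_predictive(mu, cov, W, b, num_samples=256, seed=0):
    """Monte Carlo estimate of E_{z ~ N(mu, cov)}[softmax(W z + b)].

    W and b form a hypothetical linear decoder head used only for illustration.
    """
    rng = np.random.default_rng(seed)
    z = rng.multivariate_normal(mu, cov, size=num_samples)  # (S, d) latent samples
    probs = softmax(z @ W.T + b)                             # (S, k) per-sample predictions
    return probs.mean(axis=0)                                # (k,) averaged predictive


def plug_in_predictive(mu, W, b):
    """Point-estimate prediction softmax(W mu + b), ignoring posterior uncertainty."""
    return softmax(W @ mu + b)
```

With `mu = [3, 0]`, `W = [[1, 0], [0, 1], [-1, -1]]` and `b = 0`, the plug-in prediction puts about 0.95 on the first outcome, while averaging over a wide covariance such as `25 * I` gives a noticeably less confident prediction; with a near-zero covariance the two coincide.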

Abstract

Meta-reinforcement learning trains a single reinforcement learning agent on a distribution of tasks so that it quickly generalizes to new tasks outside the training set at test time. From a Bayesian perspective, one can interpret this as performing amortized variational inference on the posterior distribution over training tasks. Among the various meta-reinforcement learning approaches, a common method is to represent this distribution with a point estimate using a recurrent neural network. We show how one can augment this point estimate to give full distributions through the Laplace approximation, either at the start of, during, or after learning, without modifying the base model architecture. With our approximation, we are able to estimate distribution statistics (e.g., the entropy) of non-Bayesian agents and observe that point-estimate-based methods produce overconfident estimators that do not satisfy consistency. Furthermore, when comparing our approach to full-distribution-based learning of the task posterior, our method performs on par with variational baselines while having far fewer parameters.
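The Laplace step described above can be made concrete with a minimal sketch of a Gauss-Newton (Jacobian-based) Laplace approximation around a point estimate $z^*$. It assumes a Gaussian likelihood $y \sim \mathcal{N}(f(z), \sigma^2 I)$ with a hypothetical decoder `f`, an isotropic Gaussian prior with precision $\lambda$, and treats $z^*$ as the mode, giving precision $\Lambda = J^\top J / \sigma^2 + \lambda I$ with $J = \partial f / \partial z$ evaluated at $z^*$. The finite-difference Jacobian stands in for automatic differentiation, and the sketch is not claimed to be the authors' exact construction.

```python
import numpy as np


def finite_difference_jacobian(f, z, eps=1e-5):
    """Central-difference Jacobian of f: R^d -> R^m at z, returned with shape (m, d)."""
    z = np.asarray(z, dtype=float)
    f0 = np.asarray(f(z), dtype=float).ravel()
    jac = np.zeros((f0.size, z.size))
    for j in range(z.size):
        step = np.zeros_like(z)
        step[j] = eps
        diff = np.asarray(f(z + step), dtype=float) - np.asarray(f(z - step), dtype=float)
        jac[:, j] = diff.ravel() / (2 * eps)
    return jac


def laplace_gaussian(f, z_star, noise_var=1.0, prior_precision=1.0):
    """Gauss-Newton Laplace approximation N(z_star, Lambda^{-1}) around a point estimate.

    Assumes y ~ N(f(z), noise_var * I) and z ~ N(0, prior_precision^{-1} I);
    z_star is treated as the posterior mode (e.g. a deterministic RNN's read-out).
    """
    z_star = np.asarray(z_star, dtype=float)
    jac = finite_difference_jacobian(f, z_star)
    precision = jac.T @ jac / noise_var + prior_precision * np.eye(z_star.size)
    return z_star, np.linalg.inv(precision)


def gaussian_entropy(cov):
    """Differential entropy 0.5 * log det(2 * pi * e * cov), in nats."""
    d = cov.shape[0]
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (d * np.log(2 * np.pi * np.e) + logdet)
```

For a linear decoder $f(z) = Az$ the Gauss-Newton curvature equals the exact Hessian of the negative log posterior, so the returned covariance matches $(A^\top A / \sigma^2 + \lambda I)^{-1}$; for nonlinear decoders it is only an approximation. The entropy of the resulting Gaussian is the kind of statistic the abstract refers to.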

Paper Structure

This paper contains 38 sections, 3 theorems, 31 equations, 13 figures, and 2 tables.

Key Result

Lemma 1

We can write $p(Z | \{X_i\}^n_{i=1}) = \frac{1}{p(Z)^{n-1}} \prod_{i=1}^n p(Z | X_i)$ iff $X_i \perp \!\!\! \perp X_j, \forall j \ne i$.
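When every factor is Gaussian, the factorization in Lemma 1 has a closed form in natural parameters: the precision is $\Lambda = \sum_i \Lambda_i - (n-1)\Lambda_0$ and the information vector is $\eta = \sum_i \Lambda_i \mu_i - (n-1)\Lambda_0 \mu_0$, where $(\mu_0, \Lambda_0)$ belong to the prior $p(Z)$. The sketch below assumes Gaussian factors and checks the result against a toy conjugate model (hypothetical helper `single_observation_posterior`) in which the $X_i$ are independent given $Z$. Because the prior terms are subtracted, $\Lambda$ is only guaranteed positive definite when the factors are mutually consistent, as in this toy case; approximate factors can break this.

```python
import numpy as np


def fuse_gaussian_posteriors(means, covs, prior_mean, prior_cov):
    """Combine Gaussian factors p(Z | X_i) = N(means[i], covs[i]) following Lemma 1.

    Dividing by the prior n - 1 times subtracts (n - 1) copies of its natural
    parameters. Returns the (mean, covariance) of the fused Gaussian.
    """
    n = len(means)
    prior_prec = np.linalg.inv(prior_cov)
    precs = [np.linalg.inv(c) for c in covs]
    prec = sum(precs) - (n - 1) * prior_prec
    info = sum(p @ m for p, m in zip(precs, means)) - (n - 1) * (prior_prec @ prior_mean)
    cov = np.linalg.inv(prec)
    return cov @ info, cov


def single_observation_posterior(x, prior_mean, prior_cov, noise_cov):
    """Hypothetical toy: conjugate posterior p(Z | X = x) for X ~ N(Z, noise_cov)."""
    prior_prec = np.linalg.inv(prior_cov)
    noise_prec = np.linalg.inv(noise_cov)
    cov = np.linalg.inv(prior_prec + noise_prec)
    mean = cov @ (prior_prec @ prior_mean + noise_prec @ x)
    return mean, cov
```

In the conjugate model the fused result reproduces the direct posterior with precision $\Lambda_0 + n\,S^{-1}$ (for observation-noise covariance $S$), which is a quick way to sanity-check the sign and the $(n-1)$ count on the prior term.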

Figures (13)

  • Figure 1: Final performance on the zero-shot regression task in terms of predictive cross-entropy, which should decrease over time. The left column shows results for complete model training; the middle and right columns use model pre-training with the RNN (black). The middle column includes parameter finetuning, the right column does not. The blue dashed line indicates the training cut-off ($T=50$).
  • Figure 2: Evolution of summary statistics for the posterior model during testing. The top row shows the KL-divergences between consecutive posteriors $q_t$ and $q_{t+1}$, and the bottom row shows the model entropy over time. In principle, we expect all lines to decrease gradually with more observations. When applying our Laplace approximation with summed covariances (green) after deterministic pre-training (right column), we see that the posterior becomes increasingly confident but does not converge to a stable distribution.
  • Figure 3: Average return curves during training for the reinforcement learning experiments. The dashed and solid lines (Num $Z$) indicate the number of Monte Carlo samples drawn from the posterior model during training inside the modified lower bound of Eq. \ref{eq:rl_lb}, to test the off-policy robustness of our loss. As expected, the deterministic RNN performs best, but our Laplace VRNN also outperforms the baseline VRNN.
  • Figure 4: Visualization of sampled tasks on which we evaluated our method: 1) zero-shot learning of a function (left), 2) learning a stochastic best-arm selection algorithm (middle), and 3) learning a deterministic grid exploration agent (right).
  • Figure 5: Progression of the predictive error during supervised model training for all ablations. Ablations are averaged over parameter groups, as indicated by the legend. This figure does not show the finetuning results. To reduce computation time for subsequent results, we picked $\beta=0.01$ for the reinforcement learning task ablations.
  • ...and 8 more figures
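The two summary statistics tracked in Figure 2 have closed forms for Gaussian posteriors. The sketch below computes the entropy of each $q_t$ and the KL divergence between consecutive posteriors. The caption does not fix the KL direction, so using $\mathrm{KL}(q_{t+1} \,\|\, q_t)$ is an assumption here, as is the exact 1-D conjugate model (hypothetical helper `conjugate_sequence`, prior mean zero) used to generate a well-behaved sequence of posteriors.

```python
import numpy as np


def gaussian_entropy(cov):
    """Differential entropy 0.5 * log det(2 * pi * e * cov), in nats."""
    d = cov.shape[0]
    _, logdet = np.linalg.slogdet(cov)
    return 0.5 * (d * np.log(2 * np.pi * np.e) + logdet)


def gaussian_kl(mu_p, cov_p, mu_q, cov_q):
    """KL(N(mu_p, cov_p) || N(mu_q, cov_q)), in nats."""
    d = mu_p.shape[0]
    cov_q_inv = np.linalg.inv(cov_q)
    diff = mu_q - mu_p
    _, logdet_p = np.linalg.slogdet(cov_p)
    _, logdet_q = np.linalg.slogdet(cov_q)
    return 0.5 * (np.trace(cov_q_inv @ cov_p) + diff @ cov_q_inv @ diff - d + logdet_q - logdet_p)


def posterior_statistics(means, covs):
    """Entropy of every q_t and KL(q_{t+1} || q_t) for each consecutive pair."""
    entropies = [gaussian_entropy(c) for c in covs]
    kls = [gaussian_kl(means[t + 1], covs[t + 1], means[t], covs[t])
           for t in range(len(means) - 1)]
    return entropies, kls


def conjugate_sequence(xs, prior_var=1.0, noise_var=0.5):
    """Hypothetical toy: exact 1-D posteriors over Z after t observations x ~ N(Z, noise_var)."""
    means, covs = [np.zeros(1)], [np.array([[prior_var]])]
    prec, info = 1.0 / prior_var, 0.0  # natural parameters of the zero-mean prior
    for x in xs:
        prec += 1.0 / noise_var
        info += x / noise_var
        means.append(np.array([info / prec]))
        covs.append(np.array([[1.0 / prec]]))
    return means, covs
```

In this exact model the entropy decreases at every step, while the KL terms shrink only on average, because their mean-shift term depends on each new observation. An approximate posterior, such as the summed-covariance variant discussed in the Figure 2 caption, need not show either trend.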

Theorems & Definitions (6)

  • Lemma 1
  • proof
  • Proposition 1
  • proof
  • Proposition 2
  • proof