
Unsupervised Continual Learning for Amortized Bayesian Inference

Aayush Mishra, Šimon Kucharský, Paul-Christian Bürkner

TL;DR

This work proposes a continual learning framework for amortized Bayesian inference (ABI) that decouples simulation-based pre-training from unsupervised, sequential self-consistency (SC) fine-tuning on real-world data, and introduces two adaptation strategies to mitigate catastrophic forgetting.

Abstract

Amortized Bayesian Inference (ABI) enables efficient posterior estimation using generative neural networks trained on simulated data, but its performance often degrades under model misspecification. While self-consistency (SC) training on unlabeled empirical data can enhance network robustness, current approaches are limited to static, single-task settings and fail to handle sequentially arriving data or distribution shifts. We propose a continual learning (CL) framework for ABI that decouples simulation-based pre-training from unsupervised, sequential SC fine-tuning on real-world data. To address catastrophic forgetting, we introduce two adaptation strategies: (1) SC with episodic replay, which uses a memory buffer of past observations, and (2) SC with elastic weight consolidation (EWC), which regularizes updates to preserve task-critical parameters. Across three diverse case studies, our methods significantly mitigate forgetting and yield posterior estimates that outperform standard simulation-based training and lie closer to MCMC reference posteriors, providing a viable path toward trustworthy ABI across a range of tasks.
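The two adaptation strategies named in the abstract can be illustrated with a minimal NumPy sketch. It assumes a reservoir-sampling policy for the episodic memory buffer and the standard diagonal-Fisher form of the EWC penalty, $\frac{\lambda}{2}\sum_i F_i(\theta_i - \theta_i^*)^2$, where $\theta^*$ are the parameters saved after the previous task. The names `ReplayBuffer`, `fisher_diag_from_grads`, and `ewc_penalty` are hypothetical; the paper's actual buffer policy, Fisher estimate, and SC loss are not reproduced here.

```python
import random

import numpy as np


class ReplayBuffer:
    """Hypothetical fixed-capacity episodic memory of past observations.

    Assumes reservoir sampling: after n observations have been added, each
    one is retained with probability capacity / n.
    """

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []
        self.n_seen = 0
        self.rng = random.Random(seed)

    def add(self, obs):
        self.n_seen += 1
        if len(self.items) < self.capacity:
            self.items.append(obs)
        else:
            # Overwrite a random slot with probability capacity / n_seen.
            j = self.rng.randrange(self.n_seen)
            if j < self.capacity:
                self.items[j] = obs

    def sample(self, k):
        # Draw up to k stored observations without replacement for replay.
        return self.rng.sample(self.items, min(k, len(self.items)))


def fisher_diag_from_grads(per_sample_grads):
    """Empirical diagonal Fisher: mean of squared per-sample gradients."""
    g = np.asarray(per_sample_grads, dtype=float)
    return np.mean(g ** 2, axis=0)


def ewc_penalty(params, anchor_params, fisher_diag, lam):
    """EWC penalty (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    diff = np.asarray(params, dtype=float) - np.asarray(anchor_params, dtype=float)
    return 0.5 * lam * float(np.sum(np.asarray(fisher_diag, dtype=float) * diff ** 2))


# Usage example with toy numbers.
buffer = ReplayBuffer(capacity=3)
for t in range(10):
    buffer.add(t)
print(len(buffer.items))  # 3
fisher = fisher_diag_from_grads([[1.0, 2.0], [3.0, 0.0]])
print(fisher)  # [5. 2.]
print(ewc_penalty([1.0, 2.0], [0.0, 2.0], fisher, lam=10.0))  # 25.0
```

Under these assumptions, a fine-tuning step on a new task would evaluate the SC loss on the current batch together with draws from `buffer.sample(k)` (replay), and add `ewc_penalty(...)` computed against the parameters stored after the previous task (EWC). A larger `lam` pulls the network more strongly toward those anchor parameters, which is the trade-off the $\lambda$ sweeps in the figures explore.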

Paper Structure

This paper contains 37 sections, 20 equations, 16 figures, and 4 algorithms.

Figures (16)

  • Figure 1: Experiment 1. MMD ratio (log scale) across CL tasks (0–9) for different methods. Boxes summarize variability across subsets. The dashed line marks parity with the simulation-based (SB) baseline (ratio = 1). Naive SC shows catastrophic forgetting in the CL setting, whereas the proposed methods mitigate forgetting and yield better posterior estimates than both SB and naive SC. Test-time SC also gives accurate posterior estimates.
  • Figure 2: Experiment 2. MMD ratio (log scale) for different methods, aggregated over fifteen CL tasks. The dashed line marks parity with the simulation-based baseline (ratio = 1). Naive SC shows catastrophic forgetting in the CL setting, whereas the proposed methods mitigate forgetting and outperform both the SB baseline and naive SC.
  • Figure 3: Experiment 3. MMD ratio (log scale) across CL tasks (0–8) for different methods, shown for each of the four RDM parameters. Boxes summarize variability over datasets. The dashed line marks parity with the simulation-based baseline (ratio = 1). All SC methods outperform SB-only training, and no catastrophic forgetting is evident.
  • Figure 4: Experiment 1. MMD ratio (log scale) across CL tasks (0–9) for SC-EWC with different values of the EWC hyperparameter $\lambda$. Boxes summarize variability across subsets. The dashed line marks parity with the simulation-based (SB) baseline (ratio = 1). $\lambda = 10^{2}$ shows forgetting, while the other variants perform similarly.
  • Figure 5: Experiment 1. MMD ratio (log scale) across CL tasks (0–9) for SC-ER-EWC with different values of the EWC hyperparameter $\lambda$. Boxes summarize variability across subsets. The dashed line marks parity with the simulation-based (SB) baseline (ratio = 1). Performance remains largely stable even for smaller $\lambda$ values, indicating that the robustness improvements are primarily attributable to episodic replay; adding EWC provides only marginal benefits in this setting.
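The figure captions report an MMD ratio on a log scale, with ratio = 1 marking parity with the simulation-based baseline. A minimal sketch of how such a ratio could be computed follows. It assumes the ratio is MMD(method, MCMC reference) divided by MMD(SB, MCMC reference), using a biased (V-statistic) MMD estimate with a Gaussian (RBF) kernel of fixed bandwidth and base-10 logarithms. The exact estimator, kernel, bandwidth, and log base used in the paper are not stated here, and the function names and synthetic draws are hypothetical.

```python
import numpy as np


def _as_2d(samples):
    # Treat a 1-D array as n scalar draws, i.e. shape (n, 1).
    arr = np.asarray(samples, dtype=float)
    return arr[:, None] if arr.ndim == 1 else arr


def rbf_kernel(x, y, bandwidth):
    # Pairwise Gaussian kernel between rows of x (n, d) and y (m, d).
    sq = np.sum(x**2, axis=1)[:, None] + np.sum(y**2, axis=1)[None, :] - 2.0 * x @ y.T
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth**2))


def mmd2(x, y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD with an RBF kernel."""
    x, y = _as_2d(x), _as_2d(y)
    return float(
        rbf_kernel(x, x, bandwidth).mean()
        + rbf_kernel(y, y, bandwidth).mean()
        - 2.0 * rbf_kernel(x, y, bandwidth).mean()
    )


def log10_mmd_ratio(method_draws, sb_draws, reference_draws, bandwidth=1.0):
    """log10 of MMD(method, reference) / MMD(SB, reference).

    Values below 0 mean the method is closer to the reference than the
    SB baseline; 0 corresponds to the dashed parity line (ratio = 1).
    """
    num = np.sqrt(max(mmd2(method_draws, reference_draws, bandwidth), 0.0))
    den = np.sqrt(max(mmd2(sb_draws, reference_draws, bandwidth), 0.0))
    return float(np.log10(num / den))


# Usage with synthetic Gaussian draws standing in for posterior samples.
rng = np.random.default_rng(0)
reference = rng.normal(0.0, 1.0, size=(500, 2))  # stand-in for MCMC draws
sb = rng.normal(1.5, 1.0, size=(500, 2))         # shifted SB-only posterior
method = rng.normal(0.05, 1.0, size=(500, 2))    # fine-tuned posterior
print(log10_mmd_ratio(method, sb, reference))    # negative: closer than SB
```

Because both MMDs are measured against the same reference draws, a ratio below 1 (a negative log value) means the method's posterior is closer to MCMC than the SB baseline's, which is how the dashed parity line in the figures is read. Note that the biased estimator is slightly positive even for samples from identical distributions, so very small ratios are bounded by finite-sample noise.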