
SING: SDE Inference via Natural Gradients

Amber Hu, Henry Smith, Scott Linderman

TL;DR

SING tackles the challenging problem of posterior inference for latent SDEs by applying natural gradient variational inference to a discretized, Gaussian variational family, exploiting the geometry of the exponential family to achieve fast, stable convergence. A key theoretical result shows that the discretized ELBO converges to the continuous-time ELBO at a rate of $O((\Delta t)^{1/2})$, so that optimizing the discretized objective approximately optimizes the continuous-time one as the time grid is refined. The framework is extended to GP-SDEs (SING-GP) with sparse inducing points, enabling drift learning with uncertainty quantification and scalable inference on real neural data. Across synthetic benchmarks and neural data, SING outperforms prior methods in latent-state recovery, drift estimation, and robustness to discretization, while a parallelized variant of SING achieves substantial runtime gains over its sequential counterpart. Overall, SING provides a principled, scalable tool for accurate inference in complex dynamical systems with non-conjugate structure and limited prior knowledge.
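On a discretized time grid, a Gauss-Markov chain has a block tridiagonal joint precision matrix, which is the structure Figure 1C describes for SING's variational posterior. The sketch below only illustrates where that structure comes from. It assumes a linear SDE dx = F x dt + dW with Cov(dW) = Sigma dt, discretized by Euler-Maruyama (an illustrative choice; this summary does not state SING's exact discretization), and assembles the joint precision of x_0, ..., x_{n-1} one factor at a time. `em_prior_precision` is a hypothetical helper, not part of any published SING code.

```python
import numpy as np

def em_prior_precision(F, Sigma, P0, dt, n):
    """Joint precision of x_0..x_{n-1} for an Euler-Maruyama discretization of
    dx = F x dt + dW, Cov(dW) = Sigma dt, with x_0 ~ N(0, P0) (hypothetical helper)."""
    d = F.shape[0]
    A = np.eye(d) + F * dt               # Euler-Maruyama transition matrix
    Qinv = np.linalg.inv(Sigma * dt)     # inverse transition-noise covariance
    J = np.zeros((n * d, n * d))
    J[:d, :d] += np.linalg.inv(P0)       # initial-state factor
    for k in range(n - 1):
        i = slice(k * d, (k + 1) * d)
        j = slice((k + 1) * d, (k + 2) * d)
        # Expand (x_{k+1} - A x_k)^T Qinv (x_{k+1} - A x_k): it only couples
        # neighbouring time steps, which is why J is block tridiagonal.
        J[i, i] += A.T @ Qinv @ A
        J[j, j] += Qinv
        J[i, j] += -A.T @ Qinv
        J[j, i] += -Qinv @ A
    return J

# Illustrative 2D stable-spiral drift on a grid of 50 steps.
F = np.array([[-0.1, -1.0], [1.0, -0.1]])
J = em_prior_precision(F, Sigma=0.1 * np.eye(2), P0=np.eye(2), dt=0.05, n=50)
```

Storing only the diagonal and off-diagonal blocks, rather than the dense matrix built here for clarity, is what lets such models scale linearly in the number of time steps.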

Abstract

Latent stochastic differential equation (SDE) models are important tools for the unsupervised discovery of dynamical systems from data, with applications ranging from engineering to neuroscience. In these complex domains, exact posterior inference of the latent state path is typically intractable, motivating the use of approximate methods such as variational inference (VI). However, existing VI methods for inference in latent SDEs often suffer from slow convergence and numerical instability. We propose SDE Inference via Natural Gradients (SING), a method that leverages natural gradient VI to efficiently exploit the underlying geometry of the model and variational posterior. SING enables fast and reliable inference in latent SDE models by approximating intractable integrals and parallelizing computations in time. We provide theoretical guarantees that SING approximately optimizes the intractable, continuous-time objective of interest. Moreover, we demonstrate that better state inference enables more accurate estimation of nonlinear drift functions using, for example, Gaussian process SDE models. SING outperforms prior methods in state inference and drift estimation on a variety of datasets, including a challenging application to modeling neural dynamics in freely behaving animals. Altogether, our results illustrate the potential of SING as a tool for accurate inference in complex dynamical systems, especially those characterized by limited prior knowledge and non-conjugate structure.
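The abstract mentions parallelizing computations in time. This summary does not say how SING does so; one standard technique for time-recursive computations is a parallel prefix (associative) scan, which JAX exposes as `jax.lax.associative_scan`. The NumPy sketch below is a generic illustration under that assumption: a Hillis-Steele scan over the scalar linear recurrence x_k = a_k x_{k-1} + b_k, a toy stand-in for Gauss-Markov recursions. `affine_scan` and `sequential` are hypothetical names.

```python
import numpy as np

def affine_scan(a, b):
    """Inclusive scan of affine maps f_k(x) = a[k] * x + b[k].

    Returns (A, B) such that A[k] * x0 + B[k] == f_k(...f_1(f_0(x0))...).
    Each while-loop round is one vectorized (parallelizable) step, so only
    ceil(log2(n)) rounds are needed instead of n sequential steps.
    """
    A = np.asarray(a, dtype=float).copy()
    B = np.asarray(b, dtype=float).copy()
    n = A.shape[0]
    offset = 1
    while offset < n:
        # Compose each prefix with the one ending `offset` steps earlier:
        # the earlier map is applied first, then the later one.
        A_new = A[offset:] * A[:-offset]
        B_new = A[offset:] * B[:-offset] + B[offset:]
        A[offset:], B[offset:] = A_new, B_new
        offset *= 2
    return A, B

def sequential(a, b, x0):
    """Reference O(n)-depth evaluation of the same recurrence."""
    xs, x = [], x0
    for ak, bk in zip(a, b):
        x = ak * x + bk
        xs.append(x)
    return np.array(xs)

rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.0, size=1000)
b = rng.normal(size=1000)
A, B = affine_scan(a, b)
x0 = 2.0
assert np.allclose(A * x0 + B, sequential(a, b, x0))
```

The scan does more total work than the sequential loop (O(n log n) versus O(n)), but its depth is logarithmic, which is what matters on parallel hardware such as GPUs.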

Paper Structure

This paper contains 77 sections, 166 equations, and 10 figures.

Figures (10)

  • Figure 1: An overview of SING. A: In the generative model, a low-dimensional SDE gives rise to noisy, conditionally independent observations at timestamps $\{t_i\}_{i=1}^n$. B: SING leverages NGVI to perform fast and reliable approximate posterior inference in latent SDE models. NGVI exploits the geometry of the model by preconditioning updates by an inverse Fisher information matrix, often leading to faster convergence than vanilla gradient ascent. C: On a discretized time grid $\boldsymbol{\tau}$, the variational posterior is a multivariate Gaussian distribution with a block tridiagonal precision matrix.
  • Figure 2: (Top row) We apply SING to a linear dynamical system (LDS) with Gaussian observations. A: True latents are sampled from an LDS characterized by a stable spiral. B: Observations are 10-dimensional Gaussian variables. C: True vs. inferred latents on an example trial, with 95% posterior credible intervals. D: Comparison between SING and several baselines over iterations. (Bottom row) We apply SING to simulated place cell activity, where both the prior and observation models are nonlinear. E: True latents are sampled from a Van der Pol oscillator and represent trajectories through 2D space. Tuning curves modeled by radial basis functions represent expected firing rates at each location in latent space. F: Observations are Poisson spike counts for 8 neurons. G, H: As in C, D.
  • Figure 3: A comparison of SING and VDP for drift estimation. A: True latent trajectories evolve according to the 2D Duffing equation. B: Posterior mean (arrows) and variance (shading) corresponding to a GP prior drift with an RBF kernel. SING-GP places high posterior uncertainty in regions unvisited by the true latent trajectories. C, D: Latent RMSE and dynamics RMSE for SING and VDP across three classes of prior drift (GP, neural SDE, polynomial basis). SING consistently outperforms VDP for hyperparameter learning across all three drift classes. E: Dynamics RMSE for SING and VDP as a function of the grid spacing $\Delta t$ for the neural network drift.
  • Figure 4: (Top) Recovered latent trajectories in the first 3 dimensions of a 50-dimensional embedded Lorenz attractor. (Bottom) Comparison of latent RMSE between Monte Carlo- and quadrature-based approximations of expectations. Monte Carlo results are averaged over 5 random seeds, with negligible standard errors (omitted).
  • Figure 5: Runtime comparisons between parallelized SING and its sequential counterpart.
  • ...and 5 more figures
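Figure 1B describes NGVI as preconditioning updates by the inverse Fisher information, and Figure 4 compares quadrature- and Monte Carlo-based approximations of expectations. The sketch below illustrates both ideas on a hypothetical one-dimensional toy problem (standard normal prior on x, one Poisson count y with log rate x), not on SING's latent SDE model. It relies on the standard exponential-family fact that the natural gradient with respect to the natural parameters equals the ordinary gradient with respect to the mean parameters (E[x], E[x^2]). The damping `rho`, the 20-node Gauss-Hermite rule, and all helper names are assumptions made for illustration.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Hypothetical toy model: x ~ N(0, 1) prior, y ~ Poisson(exp(x)) observation.
ETA_PRIOR = np.array([0.0, -0.5])   # natural parameters (mu/v, -1/(2v)) of N(0, 1)
Z, W = hermegauss(20)               # nodes/weights for the weight exp(-z^2 / 2)
W = W / np.sqrt(2.0 * np.pi)        # normalize so the weights sum to 1

def gauss_expect(g, mu, v):
    """E_{N(mu, v)}[g(x)] by Gauss-Hermite quadrature."""
    return np.sum(W * g(mu + np.sqrt(v) * Z))

def to_natural(mu, v):
    return np.array([mu / v, -0.5 / v])

def to_moments(eta):
    v = -0.5 / eta[1]
    return eta[0] * v, v

def ngvi_step(eta, y, rho):
    """One damped natural-gradient step on the ELBO for q(x) = N(mu, v).

    The natural gradient (inverse-Fisher-preconditioned gradient w.r.t. eta)
    equals the plain gradient w.r.t. the mean parameters m1 = mu, m2 = mu^2 + v.
    """
    mu, v = to_moments(eta)
    # Bonnet/Price identities for l(x) = y x - exp(x) (up to a constant):
    # dE[l]/dmu = E[l'(x)], dE[l]/dv = E[l''(x)] / 2
    g_mu = gauss_expect(lambda x: y - np.exp(x), mu, v)
    g_v = 0.5 * gauss_expect(lambda x: -np.exp(x), mu, v)
    # Chain rule from (mu, v) to the mean parameters (m1, m2)
    grad_m = np.array([g_mu - 2.0 * g_v * mu, g_v])
    # Convex combination keeps eta[1] < 0, so the variance stays positive
    return (1.0 - rho) * eta + rho * (ETA_PRIOR + grad_m)

eta = to_natural(0.0, 1.0)          # initialize q at the prior
for _ in range(200):
    eta = ngvi_step(eta, y=3, rho=0.5)
mu, v = to_moments(eta)

# Quadrature vs. Monte Carlo for E_q[exp(x)], which has the closed form exp(mu + v/2).
rng = np.random.default_rng(0)
exact = np.exp(mu + v / 2)
quad = gauss_expect(np.exp, mu, v)
mc = np.mean(np.exp(mu + np.sqrt(v) * rng.standard_normal(1000)))
print(abs(quad - exact), abs(mc - exact))
```

For a smooth, low-dimensional integrand like this one, a modest number of quadrature nodes is far more accurate than a comparable number of Monte Carlo samples; in higher latent dimensions the cost of tensor-product quadrature grows quickly, which is one reason both options are worth comparing.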

Theorems & Definitions (12)

  • proof
  • proof
  • proof
  • proof
  • proof: Proof of the ELBO convergence theorem (thm:ELBO-converge)
  • proof
  • proof
  • proof
  • proof
  • proof
  • ...and 2 more