Neural Stochastic Flows: Solver-Free Modelling and Inference for SDE Solutions
Naoki Kiyohara, Edward Johns, Yingzhen Li
TL;DR
The paper introduces Neural Stochastic Flows (NSFs), which learn the weak solutions of SDEs directly as conditional transition densities, enabling solver-free sampling between arbitrary time points. By combining conditional normalising flows with a flow-consistency regularisation built on forward and reverse KL bounds that use a bridge distribution, NSFs preserve key stochastic-flow properties while providing tractable log-densities. The latent extension (Latent NSFs) embeds NSFs within variational state-space models to handle irregular sampling and partial observations, and adds skip-ahead KL terms to mitigate error accumulation over long horizons. Across stochastic Lorenz, CMU Motion Capture, and stochastic Moving MNIST, NSFs achieve distributional accuracy comparable to solver-based approaches with substantial computational speedups, and achieve state-of-the-art extrapolation in the latent settings. This solver-free framework promises real-time stochastic modelling for irregular time series and digital twins, with potential extensions to action-conditioned control and hybrids with diffusion models.
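The flow-consistency idea can be illustrated with a process whose transition law is known in closed form. The sketch below is a hedged illustration, not the authors' regulariser: it assumes a scalar Ornstein-Uhlenbeck SDE (hypothetical parameters THETA, MU, SIGMA) as a stand-in for a learned Gaussian transition, and checks whether hopping s to u to t yields the same law as jumping s to t directly, measured by a Gaussian KL. The exact OU transition passes; a single Euler-Maruyama step used as a transition law (the hypothetical `euler_transition`) does not, which is the kind of inconsistency a flow-consistency penalty is meant to discourage.

```python
import math

# Hypothetical stand-in for a learned transition law: the scalar
# Ornstein-Uhlenbeck SDE dX = THETA * (MU - X) dt + SIGMA dW.
THETA, MU, SIGMA = 1.5, 0.0, 0.8


def ou_transition(x_s: float, dt: float) -> tuple[float, float]:
    """Exact mean and variance of X_{s+dt} given X_s = x_s."""
    decay = math.exp(-THETA * dt)
    return MU + (x_s - MU) * decay, SIGMA**2 / (2 * THETA) * (1 - decay**2)


def euler_transition(x_s: float, dt: float) -> tuple[float, float]:
    """One Euler-Maruyama step used as a transition law (not flow-consistent)."""
    return x_s + THETA * (MU - x_s) * dt, SIGMA**2 * dt


def compose(trans, x_s: float, dt1: float, dt2: float) -> tuple[float, float]:
    """Law of X_t after hopping s -> u -> t and integrating out X_u.

    Assumes trans is affine-Gaussian: its mean is affine in the start state
    and its variance does not depend on it, so the two-hop law is Gaussian.
    """
    m1, v1 = trans(x_s, dt1)
    m2, v2 = trans(m1, dt2)  # affine mean: E[X_t] = mean map applied to E[X_u]
    slope = trans(1.0, dt2)[0] - trans(0.0, dt2)[0]  # exact for affine maps
    return m2, slope**2 * v1 + v2  # law of total variance


def gaussian_kl(m_p: float, v_p: float, m_q: float, v_q: float) -> float:
    """KL(N(m_p, v_p) || N(m_q, v_q)) for scalar Gaussians."""
    return 0.5 * (math.log(v_q / v_p) + (v_p + (m_p - m_q) ** 2) / v_q - 1.0)


x0, dt1, dt2 = 1.2, 0.3, 0.4
for name, trans in [("exact OU", ou_transition), ("Euler step", euler_transition)]:
    direct = trans(x0, dt1 + dt2)
    two_hop = compose(trans, x0, dt1, dt2)
    print(f"{name}: KL(direct || two-hop) = {gaussian_kl(*direct, *two_hop):.3e}")
```

Because both hops here are affine-Gaussian, the two-hop marginal is available exactly. For general conditional normalising flows, integrating out the intermediate state is intractable, which is why the paper works with KL bounds involving a bridge distribution instead.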
Abstract
Stochastic differential equations (SDEs) are well suited to modelling noisy and irregularly sampled time series found in finance, physics, and machine learning. Traditional approaches require costly numerical solvers to sample between arbitrary time points. We introduce Neural Stochastic Flows (NSFs) and their latent variants, which directly learn (latent) SDE transition laws using conditional normalising flows with architectural constraints that preserve properties inherited from stochastic flows. This enables one-shot sampling between arbitrary states and yields up to two orders of magnitude speed-ups at large time gaps. Experiments on synthetic SDE simulations and on real-world tracking and video data show that NSFs maintain distributional accuracy comparable to numerical approaches while dramatically reducing computation for arbitrary time-point sampling.
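To make the cost argument concrete, the sketch below contrasts one-shot sampling from a transition law with fixed-step Euler-Maruyama integration. It is an analogy under stated assumptions, not the NSF model: it uses a scalar Ornstein-Uhlenbeck SDE (hypothetical parameters THETA, MU, SIGMA) whose exact transition is Gaussian, standing in for a learned conditional flow, and a hypothetical solver step size `step=1e-2`. The one-shot sampler costs one draw regardless of the gap, whereas the solver's step count grows linearly with the gap, which is the regime where the abstract reports the largest speed-ups.

```python
import math
import random

THETA, MU, SIGMA = 1.5, 0.0, 0.8  # hypothetical OU parameters


def sample_one_shot(x_s: float, dt: float, rng: random.Random) -> float:
    """Draw X_{s+dt} in a single step from the exact OU transition law."""
    decay = math.exp(-THETA * dt)
    mean = MU + (x_s - MU) * decay
    std = math.sqrt(SIGMA**2 / (2 * THETA) * (1 - decay**2))
    return mean + std * rng.gauss(0.0, 1.0)


def sample_euler_maruyama(
    x_s: float, dt: float, rng: random.Random, step: float = 1e-2
) -> tuple[float, int]:
    """Integrate from s to s+dt with steps of size approximately `step`.

    Returns the final state and the number of solver steps taken.
    """
    n_steps = max(1, round(dt / step))
    h = dt / n_steps
    x = x_s
    for _ in range(n_steps):
        x += THETA * (MU - x) * h + SIGMA * math.sqrt(h) * rng.gauss(0.0, 1.0)
    return x, n_steps


rng = random.Random(0)
x0 = 1.2
for gap in (0.1, 1.0, 10.0):
    x_one_shot = sample_one_shot(x0, gap, rng)
    x_solver, n = sample_euler_maruyama(x0, gap, rng)
    print(f"gap={gap}: one-shot draws=1, Euler-Maruyama steps={n}")
```

With the default step size, gaps of 0.1, 1.0 and 10.0 need 10, 100 and 1000 solver steps respectively, while the one-shot sampler always needs a single draw. A learned NSF replaces the closed-form mean and variance with a conditional normalising flow, so each sample costs a fixed, gap-independent amount of computation, but the scaling argument is the same.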
