In-Context Learning of Stochastic Differential Equations with Foundation Inference Models
Patrick Seifner, Kostadin Cvejoski, David Berghaus, Cesar Ojeda, Ramses J. Sanchez
TL;DR
This work tackles the data-driven discovery of drift and diffusion functions in stochastic differential equations (SDEs) by introducing FIM-SDE, a foundation inference model pretrained on synthetic SDEs with polynomial drift and diffusion. Built on a transformer-based neural-operator architecture, FIM-SDE performs zero-shot, in-context estimation of the drift $\mathbf{f}$ and diffusion $\mathbf{G}$ of low-dimensional SDEs. It can also be rapidly finetuned to target data, often outperforming symbolic, Gaussian process (GP), and neural SDE baselines after only a small amount of additional training. The approach generalizes robustly across canonical systems and real-world datasets, offers notable efficiency gains (e.g., up to 50x faster finetuning), and performs strongly under irregular sampling and noise. As a foundation-model-like solution, it has the potential to accelerate automated scientific discovery by enabling quick, data-efficient inference of dynamical laws from limited or noisy observations.
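To make "pretrained on synthetic SDEs with polynomial drift and diffusion" concrete, here is a minimal sketch of how one synthetic training example might be generated. It assumes a 1D SDE $dX_t = f(X_t)\,dt + g(X_t)\,dW_t$ with a cubic drift (negative leading coefficient so paths stay bounded) and a strictly positive quadratic diffusion, simulated with Euler-Maruyama on a fine grid, then irregularly subsampled and corrupted with Gaussian observation noise. The function names, coefficient ranges, grid sizes, and noise level are all hypothetical illustration choices, not the paper's data-generation procedure.

```python
import numpy as np


def sample_polynomial_sde(rng):
    """Draw a random 1D SDE dX = f(X) dt + g(X) dW with polynomial f and g.

    Hypothetical prior (an assumption, not the paper's): cubic drift with a
    negative leading coefficient so paths stay bounded, and a quadratic
    diffusion that is strictly positive by construction.
    """
    c = rng.uniform(-1.0, 1.0, size=4)
    c[3] = -rng.uniform(0.5, 1.5)            # confining cubic term
    b0, b2 = rng.uniform(0.1, 0.5), rng.uniform(0.0, 0.1)
    f = np.polynomial.Polynomial(c)          # f(x) = c0 + c1 x + c2 x^2 + c3 x^3

    def g(x):
        return b0 + b2 * x**2                # g(x) >= b0 > 0 for all x

    return f, g


def simulate_noisy_paths(f, g, rng, n_paths=5, t_max=10.0, dt=0.005,
                         n_obs=200, obs_noise=0.05):
    """Euler-Maruyama on a fine grid, then irregular subsampling plus noise."""
    n_steps = int(round(t_max / dt))
    t = np.linspace(0.0, t_max, n_steps + 1)
    x = np.empty((n_paths, n_steps + 1))
    x[:, 0] = rng.uniform(-2.0, 2.0, size=n_paths)
    for k in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=n_paths)
        x[:, k + 1] = x[:, k] + f(x[:, k]) * dt + g(x[:, k]) * dw
    # Irregular sampling: each path keeps its own sorted random subset of grid points.
    idx = np.sort(np.stack([rng.choice(n_steps + 1, n_obs, replace=False)
                            for _ in range(n_paths)]), axis=1)
    t_obs = t[idx]
    x_obs = np.take_along_axis(x, idx, axis=1)
    x_obs = x_obs + rng.normal(0.0, obs_noise, size=(n_paths, n_obs))
    return t_obs, x_obs


rng = np.random.default_rng(0)
f, g = sample_polynomial_sde(rng)
t_obs, x_obs = simulate_noisy_paths(f, g, rng)
print(t_obs.shape, x_obs.shape)  # (5, 200) (5, 200)
```

In a supervised setup of the kind described, the pair `(t_obs, x_obs)` would be the model input and the known `f`, `g` (evaluated on query points) the regression target; how FIM-SDE actually encodes these is not shown here.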
Abstract
Stochastic differential equations (SDEs) describe dynamical systems where deterministic flows, governed by a drift function, are superimposed with random fluctuations, dictated by a diffusion function. The accurate estimation (or discovery) of these functions from data is a central problem in machine learning, with wide application across the natural and social sciences. Yet current solutions either rely heavily on prior knowledge of the dynamics or involve intricate training procedures. We introduce FIM-SDE (Foundation Inference Model for SDEs), a pretrained recognition model that delivers accurate in-context (or zero-shot) estimation of the drift and diffusion functions of low-dimensional SDEs from noisy time series data and allows rapid finetuning to target datasets. Leveraging concepts from amortized inference and neural operators, we (pre)train FIM-SDE in a supervised fashion to map a large set of noisy, discretely observed SDE paths onto the space of drift and diffusion functions. We demonstrate that FIM-SDE achieves robust in-context function estimation across a wide range of synthetic and real-world processes, from canonical SDE systems (e.g., double-well dynamics or weakly perturbed Lorenz attractors) to stock price recordings and oil-price and wind-speed fluctuations, while matching the performance of symbolic, Gaussian process, and Neural SDE baselines trained on the target datasets. When finetuned to the target processes, FIM-SDE consistently outperforms all these baselines.
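For orientation, an SDE with drift $\mathbf{f}$ and diffusion $\mathbf{G}$ can be written, in a time-homogeneous form assumed here, as $d\mathbf{x}_t = \mathbf{f}(\mathbf{x}_t)\,dt + \mathbf{G}(\mathbf{x}_t)\,d\mathbf{W}_t$. The sketch below illustrates what "estimating drift and diffusion from paths" means using a classical binned conditional-moment (Kramers-Moyal) estimator on a 1D double well, $f(x) = x - x^3$ with constant diffusion $\sigma$: $f(x) \approx \mathbb{E}[\Delta X \mid X = x]/\Delta t$ and $g(x)^2 \approx \mathbb{E}[\Delta X^2 \mid X = x]/\Delta t$. This is not FIM-SDE and not one of the paper's baselines; the double-well parameters, bin grid, and function names are hypothetical. The estimator assumes noise-free, regularly sampled data, and observation noise tends to inflate its second-moment estimate, which is one reason the noisy, irregular setting described above is harder.

```python
import numpy as np


def simulate_double_well(rng, n_paths=500, n_steps=2000, dt=0.01, sigma=0.7):
    """Euler-Maruyama paths of dX = (X - X^3) dt + sigma dW (illustrative parameters)."""
    x = np.empty((n_paths, n_steps + 1))
    x[:, 0] = rng.uniform(-2.0, 2.0, size=n_paths)
    for k in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=n_paths)
        x[:, k + 1] = x[:, k] + (x[:, k] - x[:, k] ** 3) * dt + sigma * dw
    return x


def binned_moment_estimator(x, dt, edges):
    """Estimate f(x) ~ E[dX | X=x] / dt and g(x)^2 ~ E[dX^2 | X=x] / dt per bin.

    Classical conditional-moment estimator for noise-free, regularly sampled
    1D paths; the second moment carries an O(dt) bias of f(x)^2 * dt.
    """
    x_now = x[:, :-1].ravel()                   # state at the start of each step
    dx = np.diff(x, axis=1).ravel()             # increment over that step
    bin_id = np.digitize(x_now, edges) - 1      # bin index of each state
    n_bins = len(edges) - 1
    valid = (bin_id >= 0) & (bin_id < n_bins)   # drop states outside the grid
    counts = np.bincount(bin_id[valid], minlength=n_bins)
    sum_dx = np.bincount(bin_id[valid], weights=dx[valid], minlength=n_bins)
    sum_dx2 = np.bincount(bin_id[valid], weights=dx[valid] ** 2, minlength=n_bins)
    with np.errstate(invalid="ignore", divide="ignore"):  # empty bins become nan
        drift = sum_dx / (counts * dt)
        diff_sq = sum_dx2 / (counts * dt)
    return counts, drift, diff_sq


rng = np.random.default_rng(1)
dt, sigma = 0.01, 0.7
x = simulate_double_well(rng, dt=dt, sigma=sigma)
edges = np.linspace(-2.0, 2.0, 41)              # 40 bins of width 0.1
centers = 0.5 * (edges[:-1] + edges[1:])
counts, drift, diff_sq = binned_moment_estimator(x, dt, edges)
good = counts > 20_000                          # keep only well-populated bins
true_drift = centers[good] - centers[good] ** 3
print(np.max(np.abs(drift[good] - true_drift)))
```

Restricting the comparison to well-populated bins reflects a real limitation of this kind of estimator: it is only reliable where the paths spend enough time, which is part of what makes a learned, amortized estimator attractive.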
