In-Context Learning of Stochastic Differential Equations with Foundation Inference Models

Patrick Seifner, Kostadin Cvejoski, David Berghaus, Cesar Ojeda, Ramses J. Sanchez

TL;DR

This work tackles data-driven discovery of drift and diffusion functions in stochastic differential equations by introducing FIM-SDE, a foundation inference model pretrained on synthetic SDEs with polynomial drift and diffusion. Leveraging a transformer-based neural-operator architecture, FIM-SDE achieves zero-shot in-context estimation of $\mathbf{f}$ and $\mathbf{G}$ for low-dimensional SDEs and can be rapidly finetuned to target data, often outperforming symbolic, GP, and neural SDE baselines after only a small amount of additional training. The approach generalizes robustly across canonical systems and real-world datasets, with notable efficiency gains (e.g., up to 50x faster finetuning) and strong performance under irregular sampling and noise. This foundation-model-like solution has the potential to accelerate automated scientific discovery by enabling quick, data-efficient inference of dynamical laws from limited or noisy observations.

Abstract

Stochastic differential equations (SDEs) describe dynamical systems where deterministic flows, governed by a drift function, are superimposed with random fluctuations, dictated by a diffusion function. The accurate estimation (or discovery) of these functions from data is a central problem in machine learning, with wide application across the natural and social sciences. Yet current solutions either rely heavily on prior knowledge of the dynamics or involve intricate training procedures. We introduce FIM-SDE (Foundation Inference Model for SDEs), a pretrained recognition model that delivers accurate in-context (or zero-shot) estimation of the drift and diffusion functions of low-dimensional SDEs, from noisy time series data, and allows rapid finetuning to target datasets. Leveraging concepts from amortized inference and neural operators, we (pre)train FIM-SDE in a supervised fashion to map a large set of noisy, discretely observed SDE paths onto the space of drift and diffusion functions. We demonstrate that FIM-SDE achieves robust in-context function estimation across a wide range of synthetic and real-world processes -- from canonical SDE systems (e.g., double-well dynamics or weakly perturbed Lorenz attractors) to stock price recordings and oil-price and wind-speed fluctuations -- while matching the performance of symbolic, Gaussian process and Neural SDE baselines trained on the target datasets. When finetuned to the target processes, we show that FIM-SDE consistently outperforms all these baselines.
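
For orientation, the class of models described above is conventionally written in Itô form as shown below. The symbols $\mathbf{f}$ (drift) and $\mathbf{G}$ (diffusion) follow the TL;DR; the state $\mathbf{x}(t)$, the Wiener process $\mathbf{W}(t)$, and the time-homogeneous form are standard notation assumed here, not taken from this page.

```latex
d\mathbf{x}(t) = \mathbf{f}\big(\mathbf{x}(t)\big)\,dt + \mathbf{G}\big(\mathbf{x}(t)\big)\,d\mathbf{W}(t)
```

The estimation problem is then to recover $\mathbf{f}$ and $\mathbf{G}$ from noisy, discretely sampled observations of $\mathbf{x}(t)$.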

Paper Structure

This paper contains 50 sections, 44 equations, 6 figures, 12 tables.

Figures (6)

  • Figure 1: Left: Three-step pretraining strategy for the SDE discovery problem: (1) sample target drift and diffusion $\mathbf{f}$, $\mathbf{G}$ from the prior distribution $p_{prior}$; (2) simulate and corrupt SDE paths, with the circles denoting noisy observations; and (3) compute neural estimators $\mathbf{\hat{f}}_\theta, \mathbf{\hat{G}}_\theta$, and match them to the target functions. Right: Foundation Inference Model for SDEs (schematic representation). The input (context) consists of $K(L-1)$ tuples of the form $(\mathbf{y}, \Delta \mathbf{y}, \Delta \mathbf{y}^2, \Delta \tau)$ that are projected by the linear $\phi$ layers. The result is processed by a Transformer encoder $\Psi$ that returns the Context Matrix (Eq. \ref{eq:context-matrix}). This matrix is used as keys and values to $M$ Functional Attention layers: $\psi_{\mathbf{f}}$, $\psi_{\mathbf{G}}$ and $\psi_U$. The input queries are the (embedded) location $\mathbf{x}$ where we evaluate the estimated functions.
  • Figure 2: Drift and diffusion estimation for two canonical SDEs with state-dependent diffusion in the uncorrupted setting ($\rho=0.0$, $\Delta\tau = 0.002$). Left: Double-well system (Eq. \ref{eq:double-well}). Right: Synthetic 2D system (Eqs. \ref{eq:wang1} and \ref{eq:wang2}). FIM-SDE infers the target functions in-context (i.e., zero-shot mode), showing excellent agreement with the ground truth.
  • Figure 3: Sample paths (first three panels) and MMD convergence (right) for the Lorenz experiment. FIM-SDE in zero-shot mode already captures the global dynamics, though errors accumulate over time. With finetuning, FIM-SDE quickly refines its estimates of the dynamics, requiring only a few iterations and converging much faster than LatentSDE.
  • Figure 4: Paths sampled from SDEs inferred by FIM-SDE (dark blue) and by FIM-SDE finetuned on oil (left) and stock price data (right). Finetuning noticeably improves sample path quality in the oil dataset, while the strong zero-shot performance on stock prices is retained.
  • Figure 5: Drift and diffusion function estimation from FIM-SDE and baselines in all canonical SDE systems. The vector fields are estimated from observations with setup $\rho=0.0$ and $\Delta \tau=0.002$. The sample paths from FIM-SDE resemble the ground-truth paths closely.
  • ...and 1 more figure
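
The data-generation steps in the Figure 1 caption (sample a drift and diffusion, simulate and corrupt paths, assemble the context tuples) can be sketched roughly as follows. This is a minimal 1D illustration, not the authors' pipeline: the polynomial drift and diffusion, the Euler-Maruyama integrator, the additive Gaussian noise model scaled by `rho`, and all function names are assumptions made for the example.

```python
import numpy as np

def simulate_paths(drift, diffusion, x0, dt, n_steps, n_paths, rng):
    """Euler-Maruyama simulation of a 1D SDE dx = f(x) dt + G(x) dW."""
    x = np.empty((n_paths, n_steps + 1))
    x[:, 0] = x0
    for i in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=n_paths)  # Wiener increments
        x[:, i + 1] = x[:, i] + drift(x[:, i]) * dt + diffusion(x[:, i]) * dw
    return x

def corrupt(paths, rho, rng):
    """Assumed noise model: additive Gaussian noise with std rho * std(paths)."""
    return paths + rho * paths.std() * rng.normal(size=paths.shape)

def build_context(obs, times):
    """Stack (y, dy, dy^2, dtau) tuples: K paths of L points give K(L-1) rows."""
    y = obs[:, :-1]
    dy = np.diff(obs, axis=1)
    dtau = np.broadcast_to(np.diff(times), dy.shape)
    return np.stack([y, dy, dy**2, dtau], axis=-1).reshape(-1, 4)

rng = np.random.default_rng(0)
K, L, dt = 8, 128, 0.002  # K paths, L observations each, step size

# Hypothetical polynomial drift and state-dependent diffusion (illustrative only).
drift = lambda x: x - x**3
diffusion = lambda x: np.sqrt(1.0 + x**2)

times = np.arange(L) * dt
clean = simulate_paths(drift, diffusion, x0=0.0, dt=dt,
                       n_steps=L - 1, n_paths=K, rng=rng)
noisy = corrupt(clean, rho=0.1, rng=rng)
context = build_context(noisy, times)
print(context.shape)  # (1016, 4), i.e. (K * (L - 1), 4)
```

In the model described in the caption, rows like these would be projected by the linear $\phi$ layers and passed to the Transformer encoder $\Psi$; that part is not reproduced here.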