Efficient Training of Neural SDEs Using Stochastic Optimal Control
Rembert Daems, Manfred Opper, Guillaume Crevecoeur, Tolga Birdal
TL;DR
The paper addresses the computational cost of variational inference for neural SDEs with a stochastic optimal control framework that decomposes the control of the variational posterior into a linear term, available in closed form, and a learnable non-linear residual. For linear Gaussian priors with Gaussian observations, the optimal control is $u(x,t)=\sigma(t)^\top \nabla_x \log \mathcal{N}(\mathbf O; \mathbf m_x, \mathbf C+\Sigma_0)$. This closed form relies on the transition density $p(\mathbf X(T)\mid x)$ being Gaussian, as it is for linear SDEs. The non-linear residual is modeled by neural networks, which keeps the model expressive while retaining a cheap initialization and fast convergence. Experiments on Brownian motion (BM) and Markov-approximated fractional Brownian motion (MA-fBM) data show faster convergence and lower loss for the hybrid approach, suggesting a practical path to efficient, uncertainty-aware time-series modeling with neural SDEs.
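To make the closed-form term concrete, here is a minimal NumPy sketch under stated assumptions: a time-invariant linear prior $dX = AX\,dt + \sigma\,dW$ and a single observation $\mathbf O = X(T) + \varepsilon$ with $\varepsilon \sim \mathcal N(0, \Sigma_0)$. Under these assumptions $X(T)\mid X(t)=x$ is Gaussian with mean $\mathbf m_x = \Phi x$, where $\Phi = e^{A(T-t)}$, and covariance $\mathbf C$, so the gradient in the formula is $\Phi^\top(\mathbf C+\Sigma_0)^{-1}(\mathbf O - \mathbf m_x)$. The function names (`expm_taylor`, `transition_moments`, `optimal_linear_control`) are hypothetical, and the paper's own prior may be more general (for example time-varying coefficients, as $\sigma(t)$ suggests, or several observation times).

```python
import numpy as np


def expm_taylor(M, terms=20):
    """Matrix exponential via scaling and squaring of a truncated Taylor series."""
    norm = np.linalg.norm(M, 1)
    k = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    M_scaled = M / 2.0**k  # scaled so that its 1-norm is at most 1/2
    E = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for i in range(1, terms):
        term = term @ M_scaled / i
        E = E + term
    for _ in range(k):  # undo the scaling: exp(M) = exp(M / 2^k)^(2^k)
        E = E @ E
    return E


def transition_moments(A, sigma, tau):
    """Van Loan's method for dX = A X dt + sigma dW:
    X(t + tau) | X(t) = x  ~  N(Phi x, C)."""
    d = A.shape[0]
    Q = sigma @ sigma.T
    block = np.zeros((2 * d, 2 * d))
    block[:d, :d] = -A
    block[:d, d:] = Q
    block[d:, d:] = A.T
    F = expm_taylor(block * tau)
    Phi = F[d:, d:].T  # e^{A tau}
    C = Phi @ F[:d, d:]  # integral of e^{A s} Q e^{A^T s} ds over [0, tau]
    return Phi, C


def optimal_linear_control(x, t, O, T, A, sigma, Sigma0):
    """u(x, t) = sigma^T grad_x log N(O; m_x, C + Sigma0), with m_x = Phi x."""
    Phi, C = transition_moments(A, sigma, T - t)
    m_x = Phi @ x
    S = C + Sigma0
    # grad_x log N(O; Phi x, S) = Phi^T S^{-1} (O - Phi x)
    grad = Phi.T @ np.linalg.solve(S, O - m_x)
    return sigma.T @ grad
```

With $A=0$ and $\sigma=1$ (a Brownian-motion prior) the sketch reduces to $u = (\mathbf O - x)/(T - t + \Sigma_0)$, a bridge-like drift that pulls paths toward the noisy observation.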
Abstract
We present a hierarchical, control-theory-inspired method for variational inference (VI) for neural stochastic differential equations (SDEs). While VI for neural SDEs is a promising avenue for uncertainty-aware reasoning in time series, it is computationally challenging due to the iterative nature of maximizing the evidence lower bound (ELBO). In this work, we propose to decompose the control term into linear and residual non-linear components and, using stochastic optimal control, derive an optimal control term for linear SDEs. Modeling the non-linear component by a neural network, we show how to efficiently train neural SDEs without sacrificing their expressive power. Since the linear part of the control term is optimal and does not need to be learned, training can be initialized at a lower cost, and we observe faster convergence.
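The decomposition can be illustrated on a toy problem, assuming a scalar Brownian-motion prior $dX = \sigma\,dW$ with one noisy observation of $X(T)$. The control is the closed-form linear term plus a small residual network whose output layer starts at zero, so training begins exactly at the optimal linear control. The loss is a Monte Carlo estimate of the negative ELBO, using Euler–Maruyama paths and the Girsanov form of the path-space KL divergence, $\tfrac12\int_0^T \lVert u\rVert^2\,dt$. Everything here (the prior, `Residual`, `neg_elbo`, and the parameter values) is a hypothetical sketch rather than the authors' implementation, and no optimizer is shown.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy problem: scalar prior dX = sigma dW with X(0) = x0, and one
# observation O = X(T) + noise, noise ~ N(0, s0).
sigma, T, s0, x0, O = 1.0, 1.0, 0.1, 0.0, 1.5


def u_linear(x, t):
    """Closed-form optimal control: sigma * d/dx log N(O; x, sigma^2 (T - t) + s0)."""
    return sigma * (O - x) / (sigma**2 * (T - t) + s0)


class Residual:
    """Hypothetical tiny MLP r(x, t); the zero output layer makes r == 0 at init."""

    def __init__(self, hidden=16):
        self.W1 = rng.normal(scale=0.5, size=(hidden, 2))
        self.b1 = np.zeros(hidden)
        self.w2 = np.zeros(hidden)
        self.b2 = 0.0

    def __call__(self, x, t):
        z = np.stack([x, np.full_like(x, t)], axis=1)  # (n_paths, 2) inputs
        return np.tanh(z @ self.W1.T + self.b1) @ self.w2 + self.b2


def neg_elbo(residual, n_paths=2000, n_steps=200):
    """Monte Carlo estimate of -ELBO = KL(q || prior) - E_q[log p(O | X(T))]."""
    dt = T / n_steps
    x = np.full(n_paths, x0)
    kl = np.zeros(n_paths)
    for k in range(n_steps):
        t = k * dt
        u = u_linear(x, t) + residual(x, t)  # hybrid control: linear + residual
        kl += 0.5 * u**2 * dt  # Girsanov: KL = E[0.5 * integral of u^2 dt]
        # Euler-Maruyama step of the controlled SDE dX = sigma * u dt + sigma dW
        x = x + sigma * u * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)
    log_lik = -0.5 * ((O - x) ** 2 / s0 + np.log(2 * np.pi * s0))
    return float(np.mean(kl - log_lik))
```

Because this toy prior is itself linear, the linear term alone already yields the exact posterior, so `neg_elbo(Residual())` should approximate $-\log \mathcal N(\mathbf O; x_0, \sigma^2 T + s_0)$ up to Monte Carlo and discretization error, and any nonzero residual raises the expected loss. In the setting the abstract describes, the residual network is what captures the non-linear part of the control.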
