Latent SDEs on Homogeneous Spaces
Sebastian Zeng, Florian Graf, Roland Kwitt
TL;DR
The paper proposes learning latent stochastic dynamics on homogeneous spaces, where the dynamics are induced by Lie group actions, with a primary focus on SDEs on the unit sphere. Training combines a one-step geometric Euler–Maruyama solver with discretize-then-optimize gradients, which makes variational inference efficient. On the sphere, the KL divergence term also takes a notably simple form. Empirically, the method achieves competitive or state-of-the-art results on interpolation and per-time-point classification/regression across multiple datasets, with favorable runtime compared to more flexible neural SDEs. The work offers a principled, geometry-aware alternative to fully general neural SDEs that trades some model capacity for tractable training. It also leaves room for extensions to other homogeneous spaces and more advanced numerical schemes.
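The following minimal NumPy sketch shows the kind of one-step geometric Euler–Maruyama update the TL;DR refers to. It is specialized to the unit sphere S^{n-1}, with SO(n) acting by rotation. Each step forms a random element of the Lie algebra so(n) (a skew-symmetric matrix) from a drift term and Brownian increments. The matrix exponential maps that element to SO(n), and the result rotates the current state. Because rotations preserve the Euclidean norm, the iterates stay on the sphere without any projection step. This is an illustration under stated assumptions, not the paper's implementation. The drift parameterization (time-dependent coefficients on a fixed basis of so(n)), the isotropic noise scale `sigma`, and all function names are hypothetical, and no Itô/Stratonovich correction details are modeled.

```python
import numpy as np


def skew_basis(n):
    """Basis E_ij = e_i e_j^T - e_j e_i^T (i < j) of the Lie algebra so(n)."""
    basis = []
    for i in range(n):
        for j in range(i + 1, n):
            E = np.zeros((n, n))
            E[i, j], E[j, i] = 1.0, -1.0
            basis.append(E)
    return np.stack(basis)  # shape (n(n-1)/2, n, n)


def expm_skew(K):
    """exp(K) for a real skew-symmetric K, computed in closed form via the Hermitian matrix 1j*K."""
    lam, U = np.linalg.eigh(1j * K)  # 1j*K is Hermitian, so lam is real
    # K = -1j * (1j*K), hence exp(K) = U diag(exp(-1j*lam)) U^H
    Q = (U * np.exp(-1j * lam)) @ U.conj().T
    return Q.real  # imaginary part is round-off only; Q lies in SO(n)


def geometric_em_sphere(z0, drift_coeffs, sigma, dt, n_steps, rng):
    """Hypothetical sketch: z_{k+1} = exp(sum_i (a_i(t_k) dt + sigma dW_i) E_i) z_k on S^{n-1}."""
    E = skew_basis(z0.shape[0])
    z = z0 / np.linalg.norm(z0)  # place the initial state on the sphere
    path = [z]
    for k in range(n_steps):
        t = k * dt
        dW = rng.normal(scale=np.sqrt(dt), size=E.shape[0])  # Brownian increments
        # Combine drift and noise into one skew-symmetric increment in so(n)
        K = np.tensordot(drift_coeffs(t) * dt + sigma * dW, E, axes=1)
        z = expm_skew(K) @ z  # single group-exponential step; |z| = 1 up to round-off
        path.append(z)
    return np.stack(path)


def constant_drift(t):
    # Hypothetical drift: fixed coefficients on the 3 basis elements of so(3)
    return np.array([0.5, 0.0, -0.2])


rng = np.random.default_rng(0)
path = geometric_em_sphere(np.array([0.0, 0.0, 1.0]), constant_drift,
                           sigma=0.1, dt=0.01, n_steps=100, rng=rng)
print(path.shape)  # (101, 3)
print(np.allclose(np.linalg.norm(path, axis=1), 1.0))  # True
```

The key design choice is to take the step in the group rather than in the ambient space. An additive Euclidean update z + f dt + sigma dW would leave the sphere and need renormalization, whereas the exponential of a skew-symmetric matrix is always a rotation.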
Abstract
We consider the problem of variational Bayesian inference in a latent variable model where a (possibly complex) observed stochastic process is governed by the solution of a latent stochastic differential equation (SDE). Motivated by the challenges that arise when trying to learn an (almost arbitrary) latent neural SDE from data, such as efficient gradient computation, we take a step back and study a specific subclass instead. In our case, the SDE evolves on a homogeneous latent space and is induced by stochastic dynamics of the corresponding (matrix) Lie group. In learning problems, SDEs on the unit n-sphere are arguably the most relevant incarnation of this setup. Notably, for variational inference, the sphere not only allows using a truly uninformative prior but also yields a particularly simple and intuitive expression for the Kullback–Leibler divergence between the approximate posterior and prior processes in the evidence lower bound. Experiments demonstrate that a latent SDE of the proposed type can be learned efficiently by means of an existing one-step geometric Euler–Maruyama scheme. Despite restricting ourselves to a less rich class of SDEs, we achieve competitive or even state-of-the-art results on various time series interpolation/classification problems.
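The following is background on why a KL term between two SDE path measures is tractable at all. It gives the standard Girsanov-based expression in the generic Euclidean setting. The sphere-specific expression derived in the paper is not reproduced here. As hypothetical notation, take an approximate posterior SDE $dZ_t = h_\phi(Z_t, t)\,dt + \sigma(Z_t, t)\,dW_t$ with path measure $\mathbb{Q}_\phi$. Take a prior SDE $dZ_t = h_\theta(Z_t, t)\,dt + \sigma(Z_t, t)\,dW_t$ with path measure $\mathbb{P}_\theta$. Both share the same invertible diffusion $\sigma$ and the same initial distribution. Under the usual regularity (e.g., Novikov) conditions:

$$
D_{\mathrm{KL}}\big(\mathbb{Q}_\phi \,\|\, \mathbb{P}_\theta\big)
= \frac{1}{2} \int_0^T \mathbb{E}_{Z \sim \mathbb{Q}_\phi}\!\left[ \big\| \sigma(Z_t, t)^{-1}\big(h_\phi(Z_t, t) - h_\theta(Z_t, t)\big) \big\|^2 \right] dt .
$$

If the initial distributions differ, their KL divergence is added to this term. Sharing $\sigma$ matters: path measures with different diffusion coefficients are generally mutually singular, in which case the divergence is infinite.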
