Convergence of SGD for Training Neural Networks with Sliced Wasserstein Losses
Eloi Tanguy
TL;DR
This work provides, to the author's knowledge, the first rigorous convergence guarantees for SGD when training neural networks with Sliced Wasserstein (SW) losses. It does so by framing the problem as non-smooth, non-convex optimization and using the theory of Clarke differentials. It shows that, as the step size decreases, interpolated fixed-step SGD trajectories approach solutions of the sub-gradient flow of the population loss $F$. It also shows that, under stricter assumptions, noised and projected SGD schemes have long-run limits that approach a set of generalized critical points. Together, these results provide theoretical support for the convergence observed in practice. The analysis relies on the piecewise smoothness of the network maps, Lipschitz regularity, and the path-differentiable structure of the SW loss, and the results extend to SW losses of order $p$ under additional assumptions. The strongest results currently require discrete input measures. Future work could extend the theory to non-discrete inputs and learned projections, and explore connections to SW flows and other Optimal Transport (OT) based losses.
Abstract
Optimal Transport has sparked considerable interest in recent years, in particular thanks to the Wasserstein distance, which provides a geometrically sensible and intuitive way of comparing probability measures. For computational reasons, the Sliced Wasserstein (SW) distance was introduced as an alternative to the Wasserstein distance, and has been used to train generative Neural Networks (NNs). While convergence of Stochastic Gradient Descent (SGD) has been observed in practice in this setting, there is, to our knowledge, no theoretical guarantee for this observation. Leveraging recent work by Bianchi et al. (2022) on the convergence of SGD for non-smooth and non-convex functions, we aim to bridge this gap and provide a realistic context under which fixed-step SGD trajectories for the SW loss on NN parameters converge. More precisely, we show that the trajectories approach the set of solutions of the (sub)gradient flow equation as the step size decreases. Under stricter assumptions, we show a much stronger convergence result for noised and projected SGD schemes, namely that the long-run limits of the trajectories approach a set of generalised critical points of the loss function.
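To make the setting concrete, here is a minimal NumPy sketch of fixed-step SGD on a Monte Carlo SW loss. It is an illustration under simplifying assumptions, not the paper's code. Both measures are assumed to be uniform over the same number $n$ of points in $\mathbb{R}^d$, the order is $p = 2$, and a fixed number of directions is drawn uniformly on the sphere at each step. The "network" is assumed to be the identity map, so the parameters are the support points themselves. The helper names `sw2_loss_and_grad` and `sgd`, and the Gaussian-noise and ball-projection options, are hypothetical stand-ins for the noised and projected schemes rather than their exact form. The returned gradient is exact wherever every sorting permutation is unique. At ties the loss is only piecewise smooth, which is the kind of non-smoothness that motivates the Clarke-differential analysis.

```python
import numpy as np


def sw2_loss_and_grad(x, y, n_proj, rng):
    """Monte Carlo estimate of SW_2^2 between the uniform empirical measures
    on the rows of x and y (both of shape (n, d)), and its gradient in x.

    The gradient is exact wherever each sorting permutation is unique; at
    ties the loss is not differentiable and the returned vector corresponds
    to the tie-breaking chosen by argsort.
    """
    n, d = x.shape
    # Random directions drawn uniformly on the unit sphere S^{d-1}.
    theta = rng.normal(size=(n_proj, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    px = x @ theta.T  # (n, n_proj): 1D projections of x
    py = y @ theta.T  # (n, n_proj): 1D projections of y
    # In 1D, optimal transport between uniform measures matches sorted values.
    ix = np.argsort(px, axis=0)
    px_sorted = np.take_along_axis(px, ix, axis=0)
    py_sorted = np.sort(py, axis=0)
    diff_sorted = px_sorted - py_sorted
    # Average over directions of the 1D W_2^2 (each is a mean over points).
    loss = np.mean(diff_sorted ** 2)
    # Send each sorted residual back to the x-point it came from.
    diff = np.empty_like(diff_sorted)
    np.put_along_axis(diff, ix, diff_sorted, axis=0)
    # dL/dx_i = 2 / (n * n_proj) * sum_j diff[i, j] * theta_j
    grad = (2.0 / (n * n_proj)) * diff @ theta
    return loss, grad


def sgd(x0, y, step, n_iter, n_proj=20, noise=0.0, radius=None, seed=0):
    """Fixed-step SGD on the support points x. The only randomness in the
    gradient comes from the sampled directions. Optional Gaussian noise and
    projection onto a Euclidean ball are illustrative variants."""
    rng = np.random.default_rng(seed)
    x = x0.astype(float)  # copy, so x0 is left untouched
    losses = []
    for _ in range(n_iter):
        loss, grad = sw2_loss_and_grad(x, y, n_proj, rng)
        losses.append(loss)
        direction = grad
        if noise > 0:
            # Noised step: x <- x - step * (grad + noise * xi), xi ~ N(0, I).
            direction = grad + noise * rng.normal(size=x.shape)
        x = x - step * direction
        if radius is not None:
            norm = np.linalg.norm(x)
            if norm > radius:
                # Projection of the whole parameter onto the ball of given radius.
                x = x * (radius / norm)
    return x, losses


# Usage: move 50 points toward a shifted Gaussian sample in the plane.
data_rng = np.random.default_rng(1)
y = data_rng.normal(size=(50, 2)) + np.array([3.0, 0.0])
x0 = data_rng.normal(size=(50, 2))
x_final, losses = sgd(x0, y, step=10.0, n_iter=300)
```

Because each point carries weight $1/n$, the gradient scales like $1/n$. This is why the step used here is of order $n$ rather than of order one.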
