Flows and Diffusions on the Neural Manifold
Daniel Saragih, Deyu Cao, Tejas Balaji
TL;DR
This work extends flow and diffusion methods to weight space by casting gradient-descent optimization as a trajectory inference problem on the neural manifold. It introduces modular architectures (weight encoders, generative meta-models) and reward fine-tuning via adjoint matching to generate task-specific weights, improve downstream initialization, and adapt to context data. The approach yields competitive in-distribution weight generation, faster convergence in downstream training, and a practical framework, meta-detectron, for detecting harmful covariate shifts. Together, these results support a principled, trajectory-grounded paradigm for weight-space generative modeling, with potential for broader meta-learning and safety-critical applications. Key technical contributions include unifying gradient-descent dynamics under a continuity-equation framework, using multi-marginal flow matching and JKOnet as practical approximations of the drift, and validating the benefits of conditioning and reward fine-tuning on both standard benchmarks and covariate-shift tasks.
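The continuity-equation view can be written out as follows. This is the standard picture under the assumption that gradient descent with a small step size is well approximated by the gradient flow of a loss L. The notation (L, ρ_t, τ, W_2) is ours and may differ from the paper's. Weights start from a source distribution ρ_0 (for example, an initialization scheme), follow the gradient flow, and their density is transported by the drift -∇L. The last line is the Jordan-Kinderlehrer-Otto (JKO) step with step size τ, whose τ → 0 limit is the Wasserstein gradient flow of the potential energy ∫ L dρ, that is, the same drift.

```latex
\begin{aligned}
&\frac{d\theta_t}{dt} = -\nabla_\theta L(\theta_t), \qquad \theta_0 \sim \rho_0, \\
&\partial_t \rho_t + \nabla_\theta \cdot \big(\rho_t \, v_t\big) = 0, \qquad v_t(\theta) = -\nabla_\theta L(\theta), \\
&\rho_{k+1} \in \arg\min_{\rho} \; \int L \, d\rho + \frac{1}{2\tau} \, W_2^2(\rho, \rho_k).
\end{aligned}
```

In this reading, flow matching regresses a velocity field onto v_t from samples of ρ_t (snapshots of weights along training runs). JKOnet-style methods instead fit an energy whose JKO steps reproduce the observed marginals. Both target the same gradient-flow drift.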
Abstract
Diffusion and flow-based generative models have achieved remarkable success in domains such as image synthesis, video generation, and natural language modeling. In this work, we extend these advances to weight-space learning by leveraging recent techniques that incorporate structural priors derived from optimization dynamics. Central to our approach is casting the weight trajectory induced by gradient descent as a trajectory inference problem. We unify several trajectory inference techniques under the common goal of matching a gradient flow, providing a theoretical framework for treating optimization paths as an inductive bias. We further explore architectural and algorithmic choices, including reward fine-tuning by adjoint matching, autoencoders for latent weight representations, conditioning on task-specific context data, and informative source distributions such as Kaiming uniform. Experiments demonstrate that our method matches or surpasses baselines in generating in-distribution weights, improves initialization for downstream training, and supports fine-tuning to enhance performance. Finally, we illustrate a practical application in safety-critical systems, detecting harmful covariate shifts, where our method outperforms the closest comparable baseline.
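To make the trajectory inference framing concrete, here is a minimal sketch of multi-marginal flow matching on gradient-descent snapshots with a Kaiming-uniform source. Everything in it is an illustrative assumption rather than the authors' pipeline. A toy least-squares task stands in for a neural network. The helper names (`gd_snapshots`, `mmfm_batch`, `integrate`) are hypothetical. The velocity model is a linear least-squares fit instead of a learned network. The paths are piecewise-linear through snapshots of the same run, so no coupling between marginals is needed. The source uses the ReLU-gain Kaiming-uniform bound sqrt(6 / fan_in).

```python
import numpy as np

# Toy task (assumption): least-squares regression, so a "network" is a weight vector w in R^d.
rng = np.random.default_rng(0)
d, n_data = 8, 64
X = rng.normal(size=(n_data, d))
y = X @ rng.normal(size=d)


def loss(W):
    """Loss 0.5 * mean((Xw - y)^2) for each row w of W."""
    r = W @ X.T - y
    return 0.5 * np.mean(r**2, axis=1)


def grad(W):
    """Gradient of `loss` with respect to each row of W."""
    return (W @ X.T - y) @ X / n_data


def kaiming_uniform(n, fan_in, rng):
    """Kaiming-uniform source samples with bound sqrt(6 / fan_in) (ReLU gain)."""
    bound = np.sqrt(6.0 / fan_in)
    return rng.uniform(-bound, bound, size=(n, fan_in))


def gd_snapshots(W0, lr=0.1, steps=10, n_snap=5):
    """Run plain gradient descent and keep n_snap + 1 equally spaced snapshots."""
    snaps, W = [W0], W0.copy()
    for _ in range(n_snap):
        for _ in range(steps):
            W = W - lr * grad(W)
        snaps.append(W.copy())
    return np.stack(snaps)  # shape (K + 1, N, d)


def features(W, t):
    """Hypothetical velocity-model features [w, t, 1]; a linear model for illustration."""
    return np.concatenate([W, t[:, None], np.ones((len(t), 1))], axis=1)


def mmfm_batch(snaps, rng, batch):
    """Multi-marginal flow-matching pairs along piecewise-linear paths through snapshots."""
    K = snaps.shape[0] - 1
    times = np.linspace(0.0, 1.0, K + 1)
    idx = rng.integers(0, snaps.shape[1], size=batch)   # which trajectory
    t = rng.uniform(0.0, 1.0, size=batch)
    k = np.minimum((t * K).astype(int), K - 1)          # which segment
    dt = times[k + 1] - times[k]
    s = (t - times[k]) / dt                             # position within the segment
    x0, x1 = snaps[k, idx], snaps[k + 1, idx]
    xt = (1 - s)[:, None] * x0 + s[:, None] * x1
    ut = (x1 - x0) / dt[:, None]                        # constant velocity on each segment
    return xt, t, ut


def integrate(B, W0, n_steps=100):
    """Euler-integrate the fitted velocity field from t = 0 to t = 1."""
    W, h = W0.copy(), 1.0 / n_steps
    for i in range(n_steps):
        t = np.full(W.shape[0], i * h)
        W = W + h * features(W, t) @ B
    return W


snaps = gd_snapshots(kaiming_uniform(256, d, rng))
xt, t, ut = mmfm_batch(snaps, rng, batch=4096)
B, *_ = np.linalg.lstsq(features(xt, t), ut, rcond=None)  # regress velocity onto targets

W_src = kaiming_uniform(128, d, rng)   # fresh source samples
W_gen = integrate(B, W_src)            # generated weights
init_loss, final_loss = loss(W_src).mean(), loss(W_gen).mean()
print(f"source loss {init_loss:.3f} -> generated loss {final_loss:.3f}")
```

Integrating the fitted field from fresh Kaiming-uniform samples should yield weights with far lower loss than the source, mimicking the endpoint of training. A weight-space model of the kind described in the abstract would replace the linear fit with a neural velocity field, possibly acting on autoencoder latents and conditioned on context data.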
