Flows and Diffusions on the Neural Manifold

Daniel Saragih, Deyu Cao, Tejas Balaji

TL;DR

This work extends flow and diffusion methods to weight space by casting gradient-descent optimization as a trajectory-inference problem on the neural manifold. It introduces modular architectures (weight encoders, generative meta-models) and reward fine-tuning via adjoint matching to generate task-specific weights, improve downstream initialization, and adapt to context data. The approach yields competitive in-distribution performance, faster convergence in downstream training, and a practical framework, meta-detectron, for detecting harmful covariate shifts. Collectively, it demonstrates a principled, trajectory-grounded paradigm for weight-space generative modeling with potential for broader meta-learning and safety-critical applications. Key technical contributions include unifying gradient-descent dynamics under a continuity-equation framework, leveraging multi-marginal flow matching and JKOnet as practical drift approximations, and validating the benefits of conditioning and reward tuning on both standard benchmarks and covariate-shift tasks.

Abstract

Diffusion and flow-based generative models have achieved remarkable success in domains such as image synthesis, video generation, and natural language modeling. In this work, we extend these advances to weight space learning by leveraging recent techniques to incorporate structural priors derived from optimization dynamics. Central to our approach is modeling the trajectory induced by gradient descent as a trajectory inference problem. We unify several trajectory inference techniques towards matching a gradient flow, providing a theoretical framework for treating optimization paths as inductive bias. We further explore architectural and algorithmic choices, including reward fine-tuning by adjoint matching, the use of autoencoders for latent weight representation, conditioning on task-specific context data, and adopting informative source distributions such as Kaiming uniform. Experiments demonstrate that our method matches or surpasses baselines in generating in-distribution weights, improves initialization for downstream training, and supports fine-tuning to enhance performance. Finally, we illustrate a practical application in safety-critical systems: detecting harmful covariate shifts, where our method outperforms the closest comparable baseline.
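To make the trajectory-inference framing concrete, here is a minimal sketch, not the authors' pipeline. It runs plain gradient descent on a toy quadratic loss from many Kaiming-uniform-style initializations and stores the parameter clouds at a few checkpoints, which play the role of the marginals $p_{t_k}$ that a generative meta-model would be trained to match. The toy loss, the function names (`kaiming_uniform`, `collect_marginals`), and all hyperparameters are assumptions chosen for illustration.

```python
import math
import random


def kaiming_uniform(dim, fan_in, rng):
    """Sample a weight vector from U(-b, b) with b = sqrt(6 / fan_in) (ReLU gain)."""
    bound = math.sqrt(6.0 / fan_in)
    return [rng.uniform(-bound, bound) for _ in range(dim)]


def loss(theta, target):
    """Toy quadratic loss standing in for a real training loss (assumption)."""
    return 0.5 * sum((t - s) ** 2 for t, s in zip(theta, target))


def grad(theta, target):
    """Gradient of the toy quadratic loss with respect to theta."""
    return [t - s for t, s in zip(theta, target)]


def collect_marginals(n_particles=256, dim=4, fan_in=4, lr=0.1,
                      n_steps=60, checkpoint_every=10, seed=0):
    """Run gradient descent on many initializations and snapshot the particle cloud."""
    rng = random.Random(seed)
    target = [1.0] * dim  # hypothetical minimizer of the toy loss
    particles = [kaiming_uniform(dim, fan_in, rng) for _ in range(n_particles)]
    marginals = [[p[:] for p in particles]]  # samples from the source p_0
    for step in range(1, n_steps + 1):
        # One gradient descent step applied to every particle.
        particles = [
            [t - lr * g for t, g in zip(p, grad(p, target))] for p in particles
        ]
        if step % checkpoint_every == 0:
            marginals.append([p[:] for p in particles])  # samples from p_{t_k}
    return marginals, target


marginals, target = collect_marginals()
mean_losses = [sum(loss(p, target) for p in m) / len(m) for m in marginals]
```

Each entry of `marginals` is an empirical sample of one marginal along the optimization path; in this toy setting the mean loss shrinks from one checkpoint to the next, which is the structure a trajectory-inference method can exploit as an inductive bias.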

Paper Structure

This paper contains 106 sections, 9 theorems, 93 equations, 11 figures, 14 tables, 2 algorithms.

Key Result

Theorem 1

Let $\theta_0 \sim p_0$ be the initialized network parameters and let the loss ${\mathcal{L}}$ be $C^1$ in $\theta$. If $(\theta_t)_{t \geq 0}$ is the gradient descent curve, we have $p_t = \mathrm{Law}(\theta_t)$ with
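
The statement is cut off in this listing. For orientation only, and as an assumption about the intended conclusion rather than a quotation of the theorem: if the gradient descent curve is read as the gradient flow $\dot{\theta}_t = -\nabla_\theta {\mathcal{L}}(\theta_t)$, the standard result for deterministic flows is that the law $p_t$ is transported by the velocity field $v = -\nabla_\theta {\mathcal{L}}$ and satisfies the continuity equation in the weak sense:

$$
\partial_t p_t + \nabla_\theta \cdot \big( p_t \, v \big) = 0, \qquad v(\theta) = -\nabla_\theta {\mathcal{L}}(\theta), \qquad \text{equivalently} \qquad \partial_t p_t = \nabla_\theta \cdot \big( p_t \, \nabla_\theta {\mathcal{L}} \big).
$$

This is consistent with the continuity-equation framing mentioned in the TL;DR, but the exact regularity conditions and form used in the paper may differ.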

Figures (11)

  • Figure 1: Example unconditional pipeline. (1) Base model pre-training, shown here on MNIST, producing checkpoints across epochs. (2) Optional: variational autoencoder training with a weight-space reconstruction objective. (3) Generative meta-model training; here we illustrate unconditional N$\mathcal{M}$-CFM w/ (trained) VAE (our default N$\mathcal{M}$-CFM is on weight space directly) using the weight initialization from (1) as $p_0$. (4) Optional: reward fine-tuning via adjoint matching where $r(\boldsymbol{\cdot}) = -L_{\text{clf}}({\bm{X}}_{\mathrm{FashionMNIST}}; \, \boldsymbol{\cdot})$, steering the trained meta-model towards generating FashionMNIST classifiers.
  • Figure 2: Base model validation loss over the course of inference for various N$\mathcal{M}$ methods. The plots were computed on 20 out of 100 intermediate timepoints for MMFM and CFM, but restricted by design to $k$ timepoints for JKO($k$). MMFM_k refers to MMFM with $k$ intermediate marginal distributions (distributions in addition to $p_0$ and $p_1$) and likewise for JKO.
  • Figure 3: Mean $W_1$-distance ($\times 100$) between reference and generated intermediate marginals over 5 seeds of unconditional generation. The horizontal axis corresponds to increasing indices of the intermediate marginals, i.e. $k$ in $p_{t_k}$ where $t_0 = 0, \, t_6 = 1$. The plots also show the effect of using a Gaussian prior with MMFM (denoted N$\mathcal{M}$-MMFM-gauss), excluding N$\mathcal{M}$-CFM-gauss due to its large $W_1$ deviation.
  • Figure 4: Plots illustrating how AUROC and $\ell_{cdc}$ evolve over meta-detectron training iterations for CIFAR10 and Camelyon17 when $|{\bm{\mathsfit{Q}}}|=20$. See App. \ref{app:meta-detectron-training} for more figures.
  • Figure 5: Mock visualization of VAE latent space in D2NWG. Although the interpolant (dotted line connecting $z_0$ to $z_1$) does not follow the reference trajectory $p_0 \to p_1 \to \dots \to p_{data}$, the points on the line reside within the data manifold.
  • ...and 6 more figures

Theorems & Definitions (21)

  • Theorem 1 (Informal; follows Ch. 8.3 of santambrogio2015optimal)
  • Proposition 1 (Prop. A.1 of neklyudov2023actionmatching)
  • Theorem 2
  • proof: Proof of Theorem \ref{thm:ce}
  • proof
  • proof
  • Lemma 1
  • proof
  • proof: Proof of Theorem \ref{thm:mmfm-action-gap}
  • proof
  • ...and 11 more