Gradient Flow Matching for Learning Update Dynamics in Neural Network Training

Xiao Shou, Yanna Ding, Jianxi Gao

TL;DR

This paper introduces Gradient Flow Matching (GFM), a continuous-time framework that models neural network training as an optimizer-aware dynamical system using vector fields learned via conditional flow matching. By explicitly incorporating gradient-based update dynamics, GFM can forecast final converged weights from partial training sequences and extend to momentum and adaptive optimizers. Empirical results show GFM outperforms standard sequence models (e.g., LSTM) and approaches Transformer performance across synthetic and real-world settings (including CIFAR-10), while generalizing across architectures. This approach offers a principled, scalable method for predicting optimization trajectories, with potential to accelerate convergence prediction and optimization research by bridging continuous-time modeling with practical forecasting tasks.

Abstract

Training deep neural networks remains computationally intensive due to the iterative nature of gradient-based optimization. We propose Gradient Flow Matching (GFM), a continuous-time modeling framework that treats neural network training as a dynamical system governed by learned optimizer-aware vector fields. By leveraging conditional flow matching, GFM captures the underlying update rules of optimizers such as SGD, Adam, and RMSprop, enabling smooth extrapolation of weight trajectories toward convergence. Unlike black-box sequence models, GFM incorporates structural knowledge of gradient-based updates into the learning objective, facilitating accurate forecasting of final weights from partial training sequences. Empirically, GFM achieves forecasting accuracy that is competitive with Transformer-based models and significantly outperforms LSTM and other classical baselines. Furthermore, GFM generalizes across neural architectures and initializations, providing a unified framework for studying optimization dynamics and accelerating convergence prediction.
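To make the idea of forecasting final weights from a partial trajectory concrete, the following is a minimal toy sketch and not the paper's implementation. It assumes a noiseless linear regression problem, plain gradient descent, and an affine vector field fitted by least squares to finite-difference velocities; GFM itself learns its vector field with conditional flow matching. All names (`gd_trajectory`, `forecast`, `prefixes`) and all hyperparameters are hypothetical choices made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy task: L(w) = ||Xw - y||^2 / (2n) with noiseless targets,
# so w_star is the exact minimizer.
n, d = 200, 2
X = rng.normal(size=(n, d))
w_star = np.array([1.5, -0.7])
y = X @ w_star


def grad(w):
    """Gradient of the least-squares loss at w."""
    return X.T @ (X @ w - y) / n


def gd_trajectory(w0, lr, steps):
    """Run gradient descent and return the visited weights, shape (steps + 1, d)."""
    ws = [w0]
    for _ in range(steps):
        ws.append(ws[-1] - lr * grad(ws[-1]))
    return np.stack(ws)


# Observe only the first five update steps of 20 runs from random initializations.
lr, observed_steps = 0.1, 5
prefixes = [gd_trajectory(rng.normal(size=d), lr, observed_steps) for _ in range(20)]

# Regression targets: finite-difference velocities (w_{k+1} - w_k) / lr.
states = np.concatenate([p[:-1] for p in prefixes])
velocities = np.concatenate([(p[1:] - p[:-1]) / lr for p in prefixes])

# Fit an affine vector field v(w) = w @ M + b by least squares
# (an assumption of this sketch; it stands in for a learned neural vector field).
features = np.hstack([states, np.ones((len(states), 1))])
theta, *_ = np.linalg.lstsq(features, velocities, rcond=None)
M, b = theta[:d], theta[d]


def forecast(w, dt=0.1, steps=2000):
    """Integrate the fitted vector field forward with explicit Euler steps."""
    for _ in range(steps):
        w = w + dt * (w @ M + b)
    return w


# Extrapolate from the last observed weight of the first run.
w_forecast = forecast(prefixes[0][-1])
print("forecast:", w_forecast, "target:", w_star)
```

For a quadratic loss the true gradient flow is itself affine in the weights, which is why an affine fit suffices in this toy setting; momentum or adaptive optimizers would need a richer, optimizer-aware parameterization of the vector field, as the abstract describes.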

Paper Structure

This paper contains 48 sections, 9 equations, 4 figures, 9 tables, 1 algorithm.

Figures (4)

  • Figure 1: Visualization of the evolving weight distribution over training epochs. Each panel shows the probability density $p_t(\mathbf{w})$ at a given epoch $t$, with red dots indicating the actual weight positions $\mathbf{w}_0, \dots, \mathbf{w}_7$. As training progresses, weights move from a broad initialization distribution toward a more concentrated region near convergence. The final state $\mathbf{w}_7$ approximates the optimal weight $\mathbf{w}_*$.
  • Figure 2: Forecasted optimization trajectories produced by GFM for different optimizers (rows) and time steps (columns), shown for 20 test trajectories starting from $t_n$. Each row visualizes the evolution of weights from initialization to convergence, conditioned on the first five steps. GFM successfully learns smooth transitions and captures the distinct dynamics characteristic of each optimizer.
  • Figure 3: Loss trajectories of 2-layer MLPs trained with different optimizers. Each blue curve shows task-wise training loss over epochs, and red stars mark flow-matched predictions.
  • Figure 4: Optimization trajectories of weight parameters from initialization to convergence for five optimizers. Each plot visualizes 50 trajectories for linear regression tasks under Gaussian weight sampling.