Bridging Training and Merging Through Momentum-Aware Optimization
Alireza Moayedikia, Alicia Troncoso
TL;DR
UMTAM tackles the dual challenge of memory-efficient training and principled multi-task model merging by preserving a factorized momentum representation and curvature statistics throughout training and reusing them during merging. It introduces dual momentum factorization with error feedback and factorized second-order statistics to achieve memory efficiency, while curvature-aware merging uses task-saliency and sign election to resolve conflicts and balance contributions. Theoretical guarantees cover non-convex convergence, preconditioner quality, and generalization bounds for merged models, complemented by empirical results showing curvature-aware saliency outperforms magnitude-based pruning and that unified training-merging yields strong multi-task performance with rank-invariant convergence. The framework eliminates redundant Fisher computations, maintains competitive training dynamics with modest overhead, and demonstrates practical benefits for combining task-specific experts in NLP settings, with promising avenues for continual learning and federated learning.
Abstract
Training large neural networks and merging task-specific models both exploit low-rank structure and require parameter importance estimation, yet these challenges have been pursued in isolation. Current workflows compute curvature information during training, discard it, then recompute similar information for merging, wasting computation and discarding valuable trajectory data. We introduce a unified framework that maintains factorized momentum and curvature statistics during training, then reuses this information for geometry-aware model composition. The proposed method achieves memory efficiency comparable to state-of-the-art approaches while accumulating task saliency scores that enable curvature-aware merging without post-hoc Fisher computation. We establish convergence guarantees for non-convex objectives with approximation error bounded by gradient singular value decay. On natural language understanding benchmarks, curvature-aware parameter selection outperforms magnitude-only baselines across all sparsity levels, with multi-task merging improving over strong baselines. The proposed framework exhibits rank-invariant convergence and superior hyperparameter robustness compared to existing low-rank optimizers. By treating the optimization trajectory as a reusable asset rather than discarding it, our approach eliminates redundant computation while enabling more principled model composition.
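To make the merging idea concrete, the following is a minimal sketch of how curvature statistics saved during training might drive parameter selection and sign election at merge time. It is an illustration under stated assumptions, not the paper's algorithm: the saliency form (curvature times squared task-vector entry, a common Fisher-style proxy), the per-task top-k pruning, the saliency-weighted vote, and all names (`merge_task_vectors`, `keep_ratio`, `top_k_mask`) are hypothetical choices made here for clarity.

```python
from typing import List, Sequence


def sign(x: float) -> int:
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)


def top_k_mask(scores: Sequence[float], k: int) -> List[bool]:
    """Mark the k highest-scoring coordinates (ties may keep a few more)."""
    threshold = sorted(scores, reverse=True)[k - 1]
    return [s >= threshold for s in scores]


def merge_task_vectors(
    task_vectors: Sequence[Sequence[float]],
    curvatures: Sequence[Sequence[float]],
    keep_ratio: float = 0.5,
) -> List[float]:
    """Hypothetical curvature-aware merge of per-task parameter deltas.

    task_vectors[t][j]: fine-tuned minus base weight for task t, coordinate j.
    curvatures[t][j]:   nonnegative curvature estimate assumed to have been
                        accumulated during training (no post-hoc Fisher pass).
    """
    dim = len(task_vectors[0])
    k = max(1, round(keep_ratio * dim))

    # Assumed saliency: curvature * delta^2 (Fisher-style importance proxy).
    saliency = [
        [c * t * t for c, t in zip(cs, ts)]
        for cs, ts in zip(curvatures, task_vectors)
    ]
    # Per-task pruning by saliency rather than by magnitude alone.
    masks = [top_k_mask(s, k) for s in saliency]

    merged = []
    for j in range(dim):
        # Entries that survived pruning at this coordinate: (delta, saliency).
        kept = [
            (tv[j], s[j])
            for tv, s, m in zip(task_vectors, saliency, masks)
            if m[j]
        ]
        # Sign election: saliency-weighted vote over surviving entries.
        elected = sign(sum(s * sign(t) for t, s in kept))
        # Average only the entries that agree with the elected sign.
        agree = [(t, s) for t, s in kept if sign(t) == elected and s > 0]
        total = sum(s for _, s in agree)
        merged.append(sum(s * t for t, s in agree) / total if total > 0 else 0.0)
    return merged


# Two toy tasks that disagree in sign on coordinate 0.
deltas = [[1.0, -2.0, 0.0, 0.5], [-1.0, -1.0, 3.0, 0.5]]
curv = [[1.0, 1.0, 1.0, 1.0], [4.0, 1.0, 1.0, 1.0]]
print(merge_task_vectors(deltas, curv, keep_ratio=0.5))  # [-1.0, -2.0, 3.0, 0.0]
```

In this toy example both tasks have the same magnitude on coordinate 0, so a magnitude-only rule would have no basis for choosing a sign. The second task's higher curvature there tips the vote, which is the kind of conflict resolution the curvature-aware approach is meant to provide.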
