
Bridging Training and Merging Through Momentum-Aware Optimization

Alireza Moayedikia, Alicia Troncoso

TL;DR

UMTAM tackles the dual challenge of memory-efficient training and principled multi-task model merging by preserving a factorized momentum representation and curvature statistics throughout training and reusing them during merging. It introduces dual momentum factorization with error feedback and factorized second-order statistics to achieve memory efficiency, while its curvature-aware merging uses task saliency scores and sign election to resolve conflicts and balance task contributions. Theoretical guarantees cover non-convex convergence, preconditioner quality, and generalization bounds for merged models. Empirically, curvature-aware saliency outperforms magnitude-based pruning, and unified training and merging yields strong multi-task performance with rank-invariant convergence. The framework eliminates redundant Fisher computations, keeps training dynamics competitive with modest overhead, and shows practical benefits for combining task-specific experts in NLP settings; it also points to continual learning and federated learning as promising directions.
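The TL;DR does not spell out the update rule, so the following is only a minimal NumPy sketch of what "factorized momentum with error feedback" can look like for a single weight matrix. The class name `LowRankMomentum`, the SVD-based re-factorization, and the dense residual buffer `E` are all assumptions made for illustration, not UMTAM's actual update; the sketch also factorizes a single momentum buffer, since the structure of the "dual" factorization is not described here.

```python
import numpy as np


class LowRankMomentum:
    """Hypothetical sketch: rank-r momentum for one weight matrix with error feedback.

    Assumption (not the paper's exact update): each step re-factorizes the dense
    target beta * M_prev + (1 - beta) * G + E by truncated SVD, and whatever the
    rank-r factors cannot represent is carried forward in the residual E.
    """

    def __init__(self, shape, rank, beta=0.9):
        m, n = shape
        assert rank <= min(m, n), "rank must not exceed the matrix dimensions"
        self.rank = rank
        self.beta = beta
        self.U = np.zeros((m, rank))  # left factor (columns scaled by singular values)
        self.V = np.zeros((n, rank))  # right factor
        self.E = np.zeros(shape)      # error-feedback residual (kept dense for clarity)

    def momentum(self):
        # Reconstruct the rank-r momentum estimate U V^T.
        return self.U @ self.V.T

    def update(self, grad):
        # Dense target: decayed momentum + new gradient + previously dropped residual.
        target = self.beta * self.momentum() + (1.0 - self.beta) * grad + self.E
        u, s, vt = np.linalg.svd(target, full_matrices=False)
        r = self.rank
        self.U = u[:, :r] * s[:r]  # absorb singular values into the left factor
        self.V = vt[:r, :].T
        # Error feedback: store what the truncation discarded for the next step.
        self.E = target - self.momentum()
        return self.momentum()
```

Because `E` re-injects whatever the rank-`r` truncation dropped, no gradient component is lost permanently; it re-enters the factorization on later steps. A memory-efficient version would also need to compress `E`, which this sketch does not attempt.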

Abstract

Training large neural networks and merging task-specific models both exploit low-rank structure and require parameter importance estimation, yet these challenges have been pursued in isolation. Current workflows compute curvature information during training, discard it, and then recompute similar information for merging, which wastes computation and throws away valuable trajectory data. We introduce a unified framework that maintains factorized momentum and curvature statistics during training and then reuses this information for geometry-aware model composition. The proposed method achieves memory efficiency comparable to state-of-the-art approaches while accumulating task saliency scores that enable curvature-aware merging without post-hoc Fisher computation. We establish convergence guarantees for non-convex objectives, with approximation error bounded by the decay of the gradient's singular values. On natural language understanding benchmarks, curvature-aware parameter selection outperforms magnitude-only baselines at every sparsity level tested, and multi-task merging improves over strong baselines. The framework also exhibits rank-invariant convergence and greater hyperparameter robustness than existing low-rank optimizers. By treating the optimization trajectory as a reusable asset rather than discarding it, our approach eliminates redundant computation while enabling more principled model composition.
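To make the contrast between curvature-aware and magnitude-only parameter selection concrete, here is a hedged sketch. It assumes Adafactor-style row/column accumulators as the factorized second-order statistics and a diagonal-Fisher (Optimal Brain Damage style) saliency score `curvature * task_vector**2`. All function names are hypothetical, and the paper's exact saliency definition may differ.

```python
import numpy as np


def update_factored_second_moment(row, col, grad, beta2=0.999):
    """Adafactor-style factored EMA of squared gradients for one matrix.

    Stores only m + n numbers instead of m * n.
    """
    sq = grad ** 2
    row = beta2 * row + (1 - beta2) * sq.sum(axis=1)  # per-row sums, shape (m,)
    col = beta2 * col + (1 - beta2) * sq.sum(axis=0)  # per-column sums, shape (n,)
    return row, col


def reconstruct_second_moment(row, col, eps=1e-30):
    # Rank-1 reconstruction used by Adafactor: V_hat = R C^T / sum(R).
    return np.outer(row, col) / (row.sum() + eps)


def saliency_mask(task_vector, curvature, keep_frac):
    """Keep the top keep_frac entries by curvature-weighted saliency.

    Assumed score (diagonal-Fisher / OBD style): curvature * task_vector**2.
    Ties at the threshold may keep a few extra entries.
    """
    scores = curvature * task_vector ** 2
    k = max(1, int(round(keep_frac * scores.size)))
    threshold = np.partition(scores.ravel(), -k)[-k]  # k-th largest score
    return scores >= threshold


def magnitude_mask(task_vector, keep_frac):
    # Magnitude-only baseline: same ranking with curvature ignored (all ones).
    return saliency_mask(task_vector, np.ones_like(task_vector), keep_frac)
```

For a two-entry task vector `[[1.0, 0.5]]` with curvature `[[0.01, 10.0]]` and `keep_frac=0.5`, `saliency_mask` keeps the second entry (score 2.5 vs 0.01) while `magnitude_mask` keeps the first, because only the saliency score sees that the loss is far more sensitive to the smaller entry.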

Paper Structure

This paper contains 21 sections, 10 theorems, 74 equations, 8 figures, 11 tables, 2 algorithms.

Key Result

Theorem 3.1 (UMTAM Convergence for Strongly Convex Functions)

Let $f$ be $\mu$-strongly convex and $L$-smooth, with nuclear-norm smoothness constant $L_{\text{nuc}}$. Under UMTAM with constant learning rate $\eta \leq \min\{1/L,\ 2\mu/L^2\}$ and rank $r$, the resulting convergence bound depends on $\sigma_{r+1}$, the $(r+1)$-th singular value of the expected gradient covariance.
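Theorem 3.1 states its bound in terms of $\sigma_{r+1}$. A standard fact that helps interpret this dependence: by the Eckart-Young-Mirsky theorem, the best rank-$r$ approximation of a matrix has spectral-norm error exactly $\sigma_{r+1}$. The NumPy sketch below checks this on a synthetic gradient covariance with a decaying spectrum; the synthetic data and the function name are illustrative assumptions, not the paper's experiments.

```python
import numpy as np


def rank_r_truncation_error(A, r):
    """Spectral-norm error of the best rank-r approximation of A.

    By the Eckart-Young-Mirsky theorem this equals sigma_{r+1}(A).
    Returns (measured error, sigma_{r+1}); sigma_{r+1} is 0 when r >= rank.
    """
    u, s, vt = np.linalg.svd(A, full_matrices=False)
    A_r = (u[:, :r] * s[:r]) @ vt[:r, :]   # truncated SVD reconstruction
    err = np.linalg.norm(A - A_r, ord=2)   # largest singular value of the residual
    sigma_next = s[r] if r < s.size else 0.0
    return err, sigma_next


# Synthetic per-sample gradients whose coordinates have decaying scales.
rng = np.random.default_rng(1)
scales = np.array([4.0, 2.0, 1.0, 0.5, 0.25, 0.1, 0.05, 0.01])
grads = rng.standard_normal((200, 8)) * scales
cov = grads.T @ grads / grads.shape[0]   # empirical gradient covariance
err, sigma_next = rank_r_truncation_error(cov, 3)
```

The faster the spectrum decays, the smaller $\sigma_{r+1}$ is for a given rank $r$, so a rank-$r$ representation discards less of the gradient signal.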

Figures (8)

  • Figure 1: UMTAM framework overview. The training phase (left, blue) maintains factorized momentum and curvature statistics while accumulating task-specific saliency scores. These statistics (center, green) are preserved after training rather than discarded. The merging phase (right, orange) reuses this information for curvature-aware pruning, conflict resolution through importance-weighted sign election, and geometry-respecting parameter aggregation. Dashed green arrows indicate the flow of preserved statistics from training to merging, eliminating the redundant Fisher computation required by conventional sequential approaches.
  • Figure 2: Validation loss comparison across ranks and learning rates. UMTAM achieves lower loss and exhibits greater stability across hyperparameter settings.
  • Figure 3: Pruning performance across four GLUE tasks with distinct linguistic characteristics
  • Figure 4: Sparsity sensitivity comparison. UMTAM consistently outperforms TIES across all retention levels, with the largest advantage at aggressive sparsity ($k=5\%$: +6.4%).
  • Figure 5: Rank-invariance comparison across three learning rate regimes. UMTAM (blue) exhibits near-horizontal trajectories across all ranks, with variation $<0.1\%$ at its optimal $\eta = 3 \times 10^{-5}$. MoFaSGD (red) shows dramatic rank-dependence at $\eta = 1 \times 10^{-4}$, with losses ranging from 2.129 to 3.800, a 78.8% degradation. GaLore (green) follows a similar instability pattern at higher ranks.
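Figures 1 and 4 describe conflict resolution through importance-weighted sign election and a comparison against TIES. The sketch below is a hypothetical TIES-style merge in which saliency scores replace raw magnitudes in both the trimming and the sign-election steps. The function name, the saliency-weighted disjoint mean, and the tie behavior are assumptions, not UMTAM's published procedure.

```python
import numpy as np


def curvature_aware_merge(task_vectors, saliencies, keep_frac=0.2):
    """Hypothetical TIES-style merge with saliency-weighted sign election.

    task_vectors, saliencies: arrays of shape (num_tasks, ...) with matching shapes.
    Assumed steps (not the paper's exact rule):
      1. trim:  each task keeps its top keep_frac entries by saliency;
      2. elect: per-parameter sign of the saliency-weighted sum of kept values;
      3. merge: saliency-weighted mean over tasks whose kept value agrees with that sign.
    """
    tv = np.asarray(task_vectors, dtype=float)
    sal = np.asarray(saliencies, dtype=float)
    num_tasks = tv.shape[0]
    flat_sal = sal.reshape(num_tasks, -1)
    k = max(1, int(round(keep_frac * flat_sal.shape[1])))
    # Per-task threshold: k-th largest saliency (ties may keep a few extra entries).
    thresholds = np.partition(flat_sal, -k, axis=1)[:, -k]
    keep = sal >= thresholds.reshape((num_tasks,) + (1,) * (tv.ndim - 1))
    trimmed = np.where(keep, tv, 0.0)
    weights = np.where(keep, sal, 0.0)

    # Importance-weighted sign election.
    elected = np.sign((weights * trimmed).sum(axis=0))

    # Disjoint, saliency-weighted mean over tasks that agree with the elected sign.
    agree = (np.sign(trimmed) == elected) & keep & (elected != 0)
    w = np.where(agree, weights, 0.0)
    denom = w.sum(axis=0)
    merged = np.divide((w * trimmed).sum(axis=0), denom,
                       out=np.zeros_like(denom), where=denom > 0)
    return merged
```

With `keep_frac=1.0`, task vectors `[[1.0, -2.0, 0.5], [-0.5, -1.0, 3.0]]` and saliencies `[[1.0, 1.0, 1.0], [4.0, 1.0, 1.0]]` merge to `[-0.5, -1.5, 1.75]`: on the first parameter, the second task's higher saliency wins the sign election even though its value is smaller in magnitude.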

Theorems & Definitions (22)

  • Definition 1: Nuclear Norm Smoothness
  • Definition 2: Stable Rank
  • Theorem 3.1: UMTAM Convergence for Strongly Convex Functions
  • proof
  • Theorem 3.2: UMTAM Convergence for Non-convex Functions
  • proof
  • Lemma 3.3: Preconditioner Approximation Quality
  • proof
  • Theorem 3.4: Merging Quality Guarantee
  • proof
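Definition 2 refers to stable rank. Assuming the paper uses the standard definition $\|A\|_F^2 / \|A\|_2^2$, it can be computed from the singular values as below; for a nonzero matrix it always lies between 1 and the ordinary rank, which makes it a smooth measure of how concentrated the spectrum is. The function name is illustrative.

```python
import numpy as np


def stable_rank(A):
    """Stable rank ||A||_F^2 / ||A||_2^2 (standard definition, assumed to match Definition 2)."""
    s = np.linalg.svd(np.asarray(A, dtype=float), compute_uv=False)  # descending order
    if s[0] == 0:
        return 0.0  # zero matrix: define stable rank as 0
    return float((s ** 2).sum() / s[0] ** 2)
```

For example, `stable_rank(np.eye(5))` is 5.0, while `stable_rank(np.diag([2.0, 1.0]))` is 1.25 even though that matrix has rank 2.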