Memory-Efficient Gradient Unrolling for Large-Scale Bi-level Optimization
Qianli Shen, Yezhen Wang, Zhouhao Yang, Xiang Li, Haonan Wang, Yang Zhang, Jonathan Scarlett, Zhanxing Zhu, Kenji Kawaguchi
TL;DR
This work tackles the memory and scalability bottlenecks of gradient-based bi-level optimization for large-scale models. It introduces Forward Gradient Unrolling with Forward Gradient, $(\text{FG})^2\text{U}$, an unbiased stochastic meta-gradient estimator whose memory footprint does not grow with the unrolled inner depth $T$ or the meta-parameter dimension $N$, and which is highly amenable to parallelization. A theoretical convergence analysis establishes an $O(\epsilon^{-1}\rho^{-1})$ rate with $\rho=b/(N-1)$. The authors also propose a practical two-phase training paradigm that first uses faster, biased methods and then applies $(\text{FG})^2\text{U}$ for accurate refinement. A zeroth-order variant, $(\text{FG})^2\text{U}$-ZO, extends applicability to non-differentiable inner solvers. Empirically, $(\text{FG})^2\text{U}$ shows superior gradient quality and memory efficiency across data condensation, meta-learning for online language-model adaptation, and PDE-driven bi-level problems, highlighting its potential to scale bi-level optimization to very large models and distributed settings.
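To make the estimator concrete, here is a minimal sketch of forward-gradient unrolling on a small, hypothetical bi-level problem. The inner loss is ridge-style with per-coordinate weights $\phi$ and is solved by $T$ steps of gradient descent. The outer loss is a squared distance to a target. This is an illustration under these assumptions, not the authors' implementation. Forward-mode tangents are propagated alongside the inner iterates, so no trajectory is stored. Each random Rademacher direction $v$ yields the estimate $(\nabla_\phi F \cdot v)\,v$, which is unbiased because $\mathbb{E}[vv^\top]=I$. The names `unrolled_jvps`, `fg2u_estimate`, and `exact_meta_grad` are hypothetical.

```python
import numpy as np

# Hypothetical toy bilevel problem (an illustration, not from the paper):
#   inner: f(theta, phi) = 0.5*||A theta - y||^2 + 0.5*sum(phi * theta**2)
#   outer: F(theta_T)    = 0.5*||theta_T - theta_target||^2
# theta_T comes from T unrolled gradient-descent steps on f, starting at 0.


def unrolled_jvps(phi, V, A, y, theta_target, T=50, lr=0.05):
    """Forward-mode unrolling along each row of V (shape: b x N).

    Returns F(theta_T) and the b directional derivatives dF/dphi . V[k].
    Tangents are carried forward with theta, so memory does not grow
    with the unrolled depth T; the b directions are independent and
    are processed as one batch.
    """
    AtA, Aty = A.T @ A, A.T @ y
    theta = np.zeros(A.shape[1])
    dtheta = np.zeros((V.shape[0], A.shape[1]))  # one tangent per direction
    for _ in range(T):
        grad = AtA @ theta - Aty + phi * theta
        # Tangent of grad (AtA is symmetric), using pre-update theta/dtheta.
        dgrad = dtheta @ AtA + V * theta + phi * dtheta
        theta = theta - lr * grad
        dtheta = dtheta - lr * dgrad
    resid = theta - theta_target
    return 0.5 * resid @ resid, dtheta @ resid


def fg2u_estimate(phi, A, y, theta_target, num_dirs, rng, **kw):
    """Unbiased meta-gradient estimate: mean_k (dF/dphi . v_k) v_k."""
    V = rng.choice([-1.0, 1.0], size=(num_dirs, phi.size))  # E[v v^T] = I
    _, dir_derivs = unrolled_jvps(phi, V, A, y, theta_target, **kw)
    return (dir_derivs[:, None] * V).mean(axis=0)


def exact_meta_grad(phi, A, y, theta_target, **kw):
    """Exact gradient: N forward-mode passes along the coordinate axes."""
    _, g = unrolled_jvps(phi, np.eye(phi.size), A, y, theta_target, **kw)
    return g


rng = np.random.default_rng(0)
A = rng.normal(size=(8, 4))
y = rng.normal(size=8)
theta_target = rng.normal(size=4)
phi = np.full(4, 0.5)

g_exact = exact_meta_grad(phi, A, y, theta_target)
g_hat = fg2u_estimate(phi, A, y, theta_target, num_dirs=20000, rng=rng)
print("relative error:", np.abs(g_hat - g_exact).max() / np.linalg.norm(g_exact))
```

All directions share one primal trajectory and are evaluated as a single batch, so they could just as well be split across devices. Averaging over more directions reduces the variance of the estimate without introducing bias.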
Abstract
Bi-level optimization (BO) has become a fundamental mathematical framework for addressing hierarchical machine learning problems. As deep learning models continue to grow in size, the demand for scalable bi-level optimization solutions has become increasingly critical. Traditional gradient-based bi-level optimization algorithms are ill-suited to large-scale applications because of the memory cost and approximation error inherent in how they compute the meta-gradient. In this paper, we introduce $\textbf{F}$orward $\textbf{G}$radient $\textbf{U}$nrolling with $\textbf{F}$orward $\textbf{G}$radient, abbreviated as $(\textbf{FG})^2\textbf{U}$, which achieves an unbiased stochastic approximation of the meta-gradient for bi-level optimization. $(\text{FG})^2\text{U}$ circumvents the memory and approximation issues associated with classical bi-level optimization approaches, and delivers significantly more accurate gradient estimates than existing large-scale bi-level optimization approaches. Additionally, $(\text{FG})^2\text{U}$ is inherently designed to support parallel computing, enabling it to effectively leverage large-scale distributed computing systems for substantial computational efficiency. In practice, $(\text{FG})^2\text{U}$ can be combined with other methods in a cost-effective two-phase paradigm: faster but biased methods are used early in training, and $(\text{FG})^2\text{U}$ is applied later for accurate refinement. Further, $(\text{FG})^2\text{U}$ is easy to implement within popular deep learning frameworks, and can be conveniently adapted to more challenging zeroth-order bi-level optimization scenarios. We provide a thorough convergence analysis and a comprehensive practical discussion of $(\text{FG})^2\text{U}$, complemented by extensive empirical evaluations showcasing its superior performance on diverse large-scale bi-level optimization tasks. Code is available at https://github.com/ShenQianli/FG2U.
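The abstract notes that $(\text{FG})^2\text{U}$ adapts to zeroth-order scenarios. One plausible way to do this, assumed here rather than taken from the paper, is to replace the forward-mode directional derivative with a central finite difference of the outer loss. The inner solver is then only ever queried as a black box. The sketch below uses a hypothetical ridge-regression inner problem solved with `np.linalg.solve`. It checks the estimate against a closed-form reference obtained from the implicit function theorem. The names `black_box_outer_loss`, `fg2u_zo_estimate`, and `analytic_meta_grad` are illustrative.

```python
import numpy as np

# Hypothetical setup (not from the paper): the inner solver is a black box
# returning theta*(phi) = argmin_theta 0.5*||A theta - y||^2
#                                      + 0.5*sum(phi * theta**2),
# and only the outer loss F(theta*(phi)) is queried, never its derivatives.


def black_box_outer_loss(Phi, A, y, theta_target):
    """F(theta*(phi)) for each row of Phi (shape: b x N)."""
    n = A.shape[1]
    H = A.T @ A + Phi[:, :, None] * np.eye(n)  # b x N x N, diag(phi) added
    rhs = np.broadcast_to(A.T @ y, Phi.shape)[..., None]  # b x N x 1
    theta = np.linalg.solve(H, rhs)[..., 0]
    resid = theta - theta_target
    return 0.5 * np.sum(resid**2, axis=1)


def fg2u_zo_estimate(phi, A, y, theta_target, num_dirs, rng, mu=1e-3):
    """Zeroth-order forward-gradient estimate.

    The directional derivative dF/dphi . v is replaced by a central
    difference, which adds an O(mu^2) bias for smooth F.
    """
    V = rng.choice([-1.0, 1.0], size=(num_dirs, phi.size))
    f_plus = black_box_outer_loss(phi + mu * V, A, y, theta_target)
    f_minus = black_box_outer_loss(phi - mu * V, A, y, theta_target)
    dir_derivs = (f_plus - f_minus) / (2 * mu)
    return (dir_derivs[:, None] * V).mean(axis=0)


def analytic_meta_grad(phi, A, y, theta_target):
    """Reference via the implicit function theorem: -theta * H^{-1}(theta - target)."""
    H = A.T @ A + np.diag(phi)
    theta = np.linalg.solve(H, A.T @ y)
    return -theta * np.linalg.solve(H, theta - theta_target)


rng = np.random.default_rng(0)
A = rng.normal(size=(8, 4))
y = rng.normal(size=8)
theta_target = rng.normal(size=4)
phi = np.full(4, 0.5)

g_ref = analytic_meta_grad(phi, A, y, theta_target)
g_zo = fg2u_zo_estimate(phi, A, y, theta_target, num_dirs=20000, rng=rng)
print("relative error:", np.abs(g_zo - g_ref).max() / np.linalg.norm(g_ref))
```

The smoothing radius `mu` trades the $O(\mu^2)$ finite-difference bias against floating-point cancellation. The value used here is an arbitrary choice for this toy problem.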
