Memory-Reduced Meta-Learning with Guaranteed Convergence
Honglin Yang, Ji Ma, Xiao Yu
TL;DR
This paper tackles the high memory cost of optimization-based meta-learning, which stems from backward differentiation through the lower-level updates. It introduces a memory-reduced framework that estimates the hypergradient through a Hessian-inverse-vector product computed with a conjugate-gradient approach, so historical lower-level parameters/gradients need not be stored. The authors establish sublinear convergence for stochastic meta-learning, with a convergence error that decays as $O(1/|\mathcal{B}|)$ in the task batch size $|\mathcal{B}|$. In the deterministic setting, the algorithm converges to an exact solution with complexity $O(\epsilon^{-1})$. The per-iteration hypergradient computation is reduced to $O(\max\{p,q\})$. Empirical results on CIFAR-FS, FC100, miniImageNet, and tieredImageNet show substantial memory savings (over 50% relative to baselines), competitive accuracy, and faster wall-clock convergence than prevailing methods.
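The conjugate-gradient idea can be sketched on a toy problem. The NumPy code below is a minimal illustration under our own assumptions, not the authors' implementation: the quadratic lower-level objective `g`, the upper-level objective `f`, the helper names `conjugate_gradient`, `inner_solve`, and `hypergradient`, and all dimensions and step sizes are hypothetical. It evaluates the standard implicit hypergradient $\nabla F(x) = \nabla_x f - \nabla^2_{xy} g\,[\nabla^2_{yy} g]^{-1}\nabla_y f$ at the final lower-level iterate. CG obtains the Hessian-inverse-vector product using only Hessian-vector products and a fixed number of work vectors, instead of backpropagating through stored lower-level iterates.

```python
import numpy as np


def conjugate_gradient(hvp, b, max_iter=100, tol=1e-10):
    """Approximately solve H v = b for symmetric positive-definite H,
    accessing H only through the product hvp(u) = H @ u."""
    b = np.asarray(b, dtype=float)
    v = np.zeros_like(b)
    r = b.copy()              # residual b - H v, with v = 0
    d = r.copy()              # current search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Hd = hvp(d)
        alpha = rs_old / (d @ Hd)
        v += alpha * d
        r -= alpha * Hd
        rs_new = r @ r
        d = r + (rs_new / rs_old) * d
        rs_old = rs_new
    return v


# Hypothetical toy bilevel problem (not taken from the paper):
#   lower level  g(x, y) = 0.5 * y^T A y - y^T B x,   so y*(x) = A^{-1} B x
#   upper level  f(x, y) = 0.5 * ||y - t||^2 + 0.5 * lam * ||x||^2
def inner_solve(x, y, A, B, steps=100):
    """Gradient descent on g(x, .); y is overwritten, so no iterates are kept."""
    lr = 1.0 / np.linalg.norm(A, 2)   # 1 / largest eigenvalue of SPD A
    for _ in range(steps):
        y = y - lr * (A @ y - B @ x)
    return y


def hypergradient(x, y, A, B, t, lam):
    """grad_x f - grad_xy g (grad_yy g)^{-1} grad_y f, evaluated at the final
    lower-level iterate y, with the inverse-Hessian product obtained by CG."""
    v = conjugate_gradient(lambda u: A @ u, y - t)   # v ~= A^{-1} grad_y f
    # For this g, grad_xy g = -B^T, so the correction term is +B^T v.
    return lam * x + B.T @ v


rng = np.random.default_rng(0)
dim_x, dim_y, lam, lr_out = 3, 5, 0.1, 0.5
M = rng.standard_normal((dim_y, dim_y))
A = M @ M.T + dim_y * np.eye(dim_y)   # SPD lower-level Hessian
B = rng.standard_normal((dim_y, dim_x))
t = rng.standard_normal(dim_y)

x, y = np.zeros(dim_x), np.zeros(dim_y)
for _ in range(500):
    y = inner_solve(x, y, A, B)                         # warm-started lower level
    x = x - lr_out * hypergradient(x, y, A, B, t, lam)  # deterministic outer step

# Closed-form minimizer of F(x) = f(x, y*(x)) for comparison.
A_inv = np.linalg.inv(A)
x_star = np.linalg.solve(B.T @ A_inv @ A_inv @ B + lam * np.eye(dim_x),
                         B.T @ A_inv @ t)
print(np.linalg.norm(x - x_star))
```

Because `y` is overwritten rather than recorded, memory in this sketch does not grow with the number of inner steps. In a neural-network setting, the explicit product `A @ u` would be replaced by an automatic-differentiation Hessian-vector product.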
Abstract
Optimization-based meta-learning is gaining increasing traction because of its unique ability to adapt quickly to a new task using only a small amount of data. However, existing optimization-based meta-learning approaches, such as MAML, ANIL, and their variants, generally employ backpropagation for upper-level gradient estimation, which requires historical lower-level parameters/gradients and thus increases computational and memory overhead in each iteration. In this paper, we propose a meta-learning algorithm that avoids using historical parameters/gradients and significantly reduces per-iteration memory costs compared to existing optimization-based meta-learning approaches. In addition to this memory reduction, we prove that our algorithm converges sublinearly with the number of upper-level iterations, and that the convergence error decays sublinearly with the batch size of sampled tasks. In the special case of deterministic meta-learning, we also prove that our algorithm converges to an exact solution. Moreover, we show that the computational complexity of the algorithm is $\mathcal{O}(\epsilon^{-1})$, which matches existing convergence results for meta-learning while using no historical parameters/gradients. Experimental results on meta-learning benchmarks confirm the efficacy of our proposed algorithm.
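The claim that backpropagation needs historical lower-level parameters can be seen in the standard bilevel formulation. The derivation below is a sketch in our own notation, not taken from the paper. Here $f$ is the upper-level objective and $g$ the lower-level objective. The step size $\alpha$ and the number $K$ of lower-level gradient steps are hypothetical. We write $\nabla^2_{yx} g := \partial(\nabla_y g)/\partial x$ and $\nabla^2_{xy} g := (\nabla^2_{yx} g)^\top$. The first three lines give the unrolled hypergradient. The last line is the implicit hypergradient. It follows from differentiating $\nabla_y g(x, y^*(x)) = 0$ and then substituting the final iterate $y_K$ for $y^*$. The paper's exact estimator may differ.

$$
\begin{aligned}
y_{k+1} &= y_k - \alpha \nabla_y g(x, y_k), \qquad k = 0, \dots, K-1,\\
\frac{\partial y_{k+1}}{\partial x} &= \big(I - \alpha \nabla^2_{yy} g(x, y_k)\big)\frac{\partial y_k}{\partial x} - \alpha \nabla^2_{yx} g(x, y_k),\\
\nabla F_{\text{unroll}}(x) &= \nabla_x f(x, y_K) + \Big(\frac{\partial y_K}{\partial x}\Big)^{\!\top} \nabla_y f(x, y_K),\\
\nabla F_{\text{implicit}}(x) &= \nabla_x f(x, y_K) - \nabla^2_{xy} g(x, y_K)\,\big[\nabla^2_{yy} g(x, y_K)\big]^{-1} \nabla_y f(x, y_K).
\end{aligned}
$$

A reverse-mode evaluation of the unrolled form revisits every $y_k$ in the backward pass, so its memory grows linearly with $K$. The implicit form needs only $y_K$. Its inverse-Hessian-vector product can be computed iteratively, for example by conjugate gradient, from Hessian-vector products alone.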
