Efficient Bilevel Optimization with KFAC-Based Hypergradients

Disen Liao, Felix Dangel, Yaoliang Yu

Abstract

Bilevel optimization (BO) is widely applicable to many machine learning problems. Scaling BO, however, requires repeatedly computing hypergradients, which involves solving inverse Hessian-vector products (IHVPs). In practice, these operations are often approximated using crude surrogates such as one-step gradient unrolling or identity/short Neumann expansions, which discard curvature information. We build on implicit function theorem-based algorithms and propose to incorporate Kronecker-factored approximate curvature (KFAC), yielding curvature-aware hypergradients with a better performance-efficiency trade-off than Conjugate Gradient (CG) or Neumann methods, and consistently outperforming unrolling. We evaluate this approach across diverse tasks, including meta-learning and AI safety problems. On models up to BERT, we show that curvature information is valuable at scale, and KFAC can provide it with only modest memory and runtime overhead. Our implementation is available at https://github.com/liaodisen/NeuralBo.
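To make the core operation concrete, below is a minimal sketch of a KFAC-based IHVP for a single linear layer; it is an illustration under stated assumptions, not the implementation from the linked repository. It assumes the per-layer Kronecker factors $A$ (input-activation covariance) and $G$ (output-gradient covariance) have already been estimated, uses an illustrative damping value and toy layer shapes, and exploits the identity $(A \otimes G)^{-1}\,\mathrm{vec}(V) = \mathrm{vec}(G^{-1} V A^{-1})$ (for symmetric factors) to reduce the inverse curvature-vector product to two small linear solves.

```python
# Minimal sketch (not the authors' implementation): an inverse
# Hessian-vector product with a Kronecker-factored curvature
# approximation H ≈ A ⊗ G for one linear layer, using
#   (A ⊗ G)^{-1} vec(V) = vec(G^{-1} V A^{-1})   for symmetric A, G.
# The factors A (input-activation covariance) and G (output-gradient
# covariance) are assumed precomputed; damping and shapes are illustrative.
import torch


def kfac_ihvp(A: torch.Tensor, G: torch.Tensor,
              v: torch.Tensor, damping: float = 1e-3) -> torch.Tensor:
    """Approximate H^{-1} v for a layer whose curvature is A ⊗ G.

    A: (d_in, d_in) activation covariance factor
    G: (d_out, d_out) gradient covariance factor
    v: (d_out, d_in) vector reshaped to the layer's weight shape
    """
    # Split the damping across the two factors (factored Tikhonov damping)
    # so the Kronecker-factored curvature stays well conditioned.
    A_d = A + damping ** 0.5 * torch.eye(A.shape[0])
    G_d = G + damping ** 0.5 * torch.eye(G.shape[0])
    # Two small solves instead of inverting the full (d_in*d_out)^2 Hessian.
    right = torch.linalg.solve(A_d, v.T).T   # V A^{-1}
    return torch.linalg.solve(G_d, right)    # G^{-1} V A^{-1}


# Toy usage: a 4 -> 3 linear layer, with v playing the role of the
# outer-loss gradient that the hypergradient formula multiplies by H^{-1}.
d_in, d_out = 4, 3
A = torch.randn(d_in, 16); A = A @ A.T / 16   # SPD activation factor
G = torch.randn(d_out, 16); G = G @ G.T / 16  # SPD gradient factor
v = torch.randn(d_out, d_in)
print(kfac_ihvp(A, G, v).shape)               # torch.Size([3, 4])
```

Because the factors are only d_in x d_in and d_out x d_out, the solve never forms the full layer Hessian, which is what keeps the memory and runtime overhead modest relative to exact IHVP solvers.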

Figures (7)

  • Figure 1: Top: KFAC balances accuracy and efficiency better than Hessian/GGN-based IHVP approximations. $T$ is the number of CG iterations and $K$ the number of terms in the truncated Neumann series (see the sketch after this figure list). Bottom: On BERT, KFAC incurs only a small memory overhead compared to the curvature-free SAMA, while using less memory than CG/Neumann.
  • Figure 2: Results on data hypercleaning. Left two panels: existing BO algorithms improve when their IHVP solvers are replaced with IKVPs. Right two panels: KFAC-based methods compared against other non-IFT baselines, showing faster convergence in both wall-clock time and iteration count. Curves are truncated at the lowest test loss reached. The solid black line is trained on the validation dataset only, while the dashed black line is trained on the validation and noisy datasets together.
  • Figure 3: Batch size effects on data hypercleaning. KFAC consistently achieves a lower minimum test loss than CG across models and batch sizes. For the linear model, CG converges early to a worse point than KFAC.
  • Figure 4: Results on BERT data hypercleaning. We report test accuracy (y-axis) versus total training time (x-axis) on three datasets when fine-tuning 1, 7, or 12 encoder layers (columns). Each marker corresponds to one hypergradient solver. The gray region highlights the Pareto frontier.
  • Figure 5: Data hypercleaning results for a CNN on CIFAR-10 and a linear model on MNIST. For the linear model, the solid black line is the test loss reached when training on the cleaned dataset, while the dashed black line corresponds to training only on the validation set. For the CNN, the solid black line is the test loss reached when training on the validation dataset, while the dashed line corresponds to training on the validation and noisy datasets together.
  • ...and 2 more figures
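For reference, the truncated Neumann approximation that $K$ parameterizes in Figure 1 is the standard series (stated here for completeness, not quoted from the paper): for a damped Hessian $H$ and step size $\alpha$ with $\|I - \alpha H\| < 1$,

$$
H^{-1} v \;=\; \alpha \sum_{k=0}^{\infty} (I - \alpha H)^{k} v \;\approx\; \alpha \sum_{k=0}^{K} (I - \alpha H)^{k} v,
$$

so a larger $K$ yields a more accurate IHVP at the cost of one extra Hessian-vector product per term, while $T$ plays the analogous role for CG iterations.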