Functional Bilevel Optimization for Machine Learning

Ieva Petrulionyte, Julien Mairal, Michael Arbel

TL;DR

This work reframes bilevel optimization in machine learning as a functional problem, minimizing the inner objective over a function space $\mathcal{H}$ instead of over neural network parameters. By leveraging functional implicit differentiation in $L_2$ spaces, it derives a stable total gradient formula $\nabla\mathcal{F}(\omega)=g_\omega + B_\omega a_\omega^\star$ and introduces FuncID, a scalable algorithm that learns both the inner prediction and adjoint functions with neural nets. The paper proves differentiability and convergence guarantees under mild assumptions and demonstrates practical benefits on instrumental-variable regression (2SLS) and model-based reinforcement learning (CartPole), showing improved stability, faster convergence, and competitive or superior performance versus parametric bilevel methods. The functional view enables using deep nets as inner predictors while mitigating the ill-posedness and ambiguity that arise from multiple inner solutions, with open avenues to extend to RKHS, non-smooth objectives, and broader ML tasks.
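The total gradient formula above can be sanity-checked on a toy problem. The sketch below is not the paper's FuncID algorithm: it replaces the function space by a single real number and uses two hypothetical quadratic objectives chosen so that every derivative is available by hand. It also assumes the adjoint is defined as $a_\omega^\star = -C_\omega^{-1} d_\omega$, where $C_\omega$ is the second derivative of the inner objective in $h$ and $d_\omega$ is the derivative of the outer objective in $h$; the names `g`, `d`, `C`, `B`, `a` follow that assumed convention, which this summary does not spell out.

```python
# Toy scalar check of grad F(w) = g_w + B_w * a_w.
# Hypothetical objectives, chosen only so that derivatives are easy:
#   L_in(w, h)  = 0.5 * (h - 2*w)**2
#   L_out(w, h) = 0.5 * (h - 1)**2 + 0.5 * w**2


def inner_solution(w):
    # L_in(w, .) is minimized at h = 2*w.
    return 2.0 * w


def outer_value(w):
    # F(w) = L_out(w, h*(w)).
    h = inner_solution(w)
    return 0.5 * (h - 1.0) ** 2 + 0.5 * w ** 2


def total_gradient(w):
    h = inner_solution(w)
    g = w          # partial derivative of L_out in w, with h held fixed
    d = h - 1.0    # partial derivative of L_out in h
    C = 1.0        # second derivative of L_in in h
    B = -2.0       # cross derivative of L_in in (w, h)
    a = -d / C     # assumed adjoint: minimizer of 0.5*C*a**2 + d*a
    return g + B * a


def finite_difference(w, eps=1e-6):
    # Central difference on F, used as an independent check.
    return (outer_value(w + eps) - outer_value(w - eps)) / (2.0 * eps)
```

For these objectives $\mathcal{F}(\omega) = \tfrac12(2\omega-1)^2 + \tfrac12\omega^2$, so the closed-form derivative is $5\omega - 2$; `total_gradient` and `finite_difference` should both agree with it up to rounding.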

Abstract

In this paper, we introduce a new functional point of view on bilevel optimization problems for machine learning, where the inner objective is minimized over a function space. These types of problems are most often solved by using methods developed in the parametric setting, where the inner objective is strongly convex with respect to the parameters of the prediction function. The functional point of view does not rely on this assumption and notably allows using over-parameterized neural networks as the inner prediction function. We propose scalable and efficient algorithms for the functional bilevel optimization problem and illustrate the benefits of our approach on instrumental regression and reinforcement learning tasks.
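
As a reading aid, the functional bilevel problem described in the abstract can be written schematically as follows. The symbols $L_{out}$ and $L_{in}$ for the outer and inner objectives are notation assumed here for illustration; $\mathcal{H}$, $\omega$, $h^\star_\omega$ and $\mathcal{F}$ are the symbols already used in this summary.

```latex
\min_{\omega} \; \mathcal{F}(\omega) := L_{out}\big(\omega, h^\star_\omega\big)
\quad \text{s.t.} \quad
h^\star_\omega = \operatorname*{arg\,min}_{h \in \mathcal{H}} \; L_{in}(\omega, h)
```

The inner minimization runs over functions $h \in \mathcal{H}$ rather than over the parameters of a network, which is the distinction from the parametric setting that the abstract draws.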

Paper Structure

This paper contains 60 sections, 16 theorems, 140 equations, 8 figures, 2 tables, 4 algorithms.

Key Result

Theorem 2.1

Consider problem (def:funcBO) and assume that: Then, $\omega \mapsto h^\star_\omega$ is uniquely defined and is Fréchet differentiable with a Jacobian $\partial_{\omega}h^\star_{\omega}$ given by:

Figures (8)

  • Figure 1: Parametric vs functional approaches for solving problem (def:funcBO) by implicit differentiation.
  • Figure 2: Performance metrics for Instrumental Variable (IV) regression. All results are averaged over 20 runs with 5000 training samples and 588 test samples. (Left) Box plot of the test loss, with the dashed black line indicating the mean test error. (Middle) Outer loss vs training iterations. (Right) Inner loss vs training iterations. The bold lines in the middle and right plots indicate the mean loss; the shaded area corresponds to the standard deviation.
  • Figure 3: Average reward on an evaluation environment vs. training iterations on the CartPole task. (Left) Well-specified model with 32 hidden units. (Right) Misspecified model with 3 hidden units. Both plots show mean reward over 10 runs, where the shaded region is the 95% confidence interval.
  • Figure 4: Memory and time comparison of a single total gradient approximation using FuncID vs AID. (Left) Memory usage ratio of FuncID over AID vs inner model parameter dimension $p_{in}$, for various values of the output dimension $d_{v}$. (Right) Time ratio of FuncID over AID vs inner model parameter dimension $p_{in}$, averaged over several values of $d_v$ and $10^4$ evaluations. The continuous lines are experimental results obtained using a JAX implementation (jax2018github) running on a GPU. The dashed lines correspond to theoretical estimates obtained using the algorithmic costs given in (tab:table_complex) with $\gamma=12, \delta=2$ for time, and the constant factors in the memory cost fitted to the data.
  • Figure 5: The causal relationships between all variables in an Instrumental Variable (IV) causal graph, where $t$ is the treatment variable (dsprites image), $o$ is the outcome (label in $\mathbb{R}$), $x$ is the instrument and $\epsilon$ is the unobserved confounder.
  • ...and 3 more figures

Theorems & Definitions (34)

  • Theorem 2.1: Functional implicit differentiation
  • Proposition 2.2: Functional adjoint sensitivity
  • Proposition 2.3: Functional Adjoint sensitivity in $L_2$ spaces.
  • Theorem 3.1
  • Definition C.1
  • Proposition C.2
  • proof
  • proof: Proof of (thm:implicit_function)
  • proof: Proof of (prop:implicit_diff)
  • Proposition D.1
  • ...and 24 more