
Understanding the Curse of Unrolling

Sheheryar Mehmood, Florian Knoll, Peter Ochs

TL;DR

Truncating the early iterations of the derivative computation mitigates the curse of unrolling while also reducing memory requirements, and warm-starting in bilevel optimization naturally induces an implicit form of truncation, providing a practical remedy.

Abstract

Algorithm unrolling is ubiquitous in machine learning, particularly in hyperparameter optimization and meta-learning, where Jacobians of solution mappings are computed by differentiating through iterative algorithms. Although unrolling is known to yield asymptotically correct Jacobians under suitable conditions, recent work has shown that the derivative iterates may initially diverge from the true Jacobian, a phenomenon known as the curse of unrolling. In this work, we provide a non-asymptotic analysis that explains the origin of this behavior and identifies the algorithmic factors that govern it. We show that truncating early iterations of the derivative computation mitigates the curse while simultaneously reducing memory requirements. Finally, we demonstrate that warm-starting in bilevel optimization naturally induces an implicit form of truncation, providing a practical remedy. Our theoretical findings are supported by numerical experiments on representative examples.
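The phenomenon can be reproduced on a small instance of the ridge-regression objective used in Figure 1. The sketch below is our own illustration, not the authors' code. It assumes NumPy, a scalar hyperparameter $u$, the initialization $\bm x^{(0)} = 0$, $D\bm x^{(0)} = 0$, the step size $\alpha = 2/(L+m)$ that appears in the captions of Figures 3 to 5, and hand-picked data $A$, $\bm b$. The function name `unrolled_gd_errors` is hypothetical. The sketch propagates the forward-mode derivative of each gradient step and compares both sequences with the closed-form solution $\bm x^{\star}(u) = (A^\top A + uI)^{-1}A^\top \bm b$ and its derivative.

```python
import numpy as np


def unrolled_gd_errors(A, b, u, num_iters):
    """Run gradient descent on f(x, u) = ||A x - b||^2 / 2 + u ||x||^2 / 2 from
    x^(0) = 0 and propagate the forward-mode derivative D x^(k)(u) alongside.

    Returns the errors ||x^(k) - x*(u)|| and ||D x^(k)(u) - D x*(u)||.
    (Hypothetical helper for illustration.)
    """
    n = A.shape[1]
    c = A.T @ b
    H = A.T @ A + u * np.eye(n)             # Hessian of f with respect to x
    eigs = np.linalg.eigvalsh(H)            # ascending: m = eigs[0], L = eigs[-1]
    alpha = 2.0 / (eigs[-1] + eigs[0])      # step size 2 / (L + m)

    x_star = np.linalg.solve(H, c)          # solves H x* = A^T b
    dx_star = -np.linalg.solve(H, x_star)   # from differentiating H(u) x*(u) = A^T b

    x, dx = np.zeros(n), np.zeros(n)
    x_errs = [float(np.linalg.norm(x - x_star))]
    dx_errs = [float(np.linalg.norm(dx - dx_star))]
    for _ in range(num_iters):
        # Update: x <- x - alpha * (H x - c). Its derivative in u is
        # dx <- dx - alpha * (H dx + x), evaluated at the *old* x.
        dx = dx - alpha * (H @ dx + x)
        x = x - alpha * (H @ x - c)
        x_errs.append(float(np.linalg.norm(x - x_star)))
        dx_errs.append(float(np.linalg.norm(dx - dx_star)))
    return x_errs, dx_errs


# Illustrative data (an assumption, not the data behind Figure 1).
A = np.diag([1.0, 3.0])
b = np.array([1.0, 10.0])
x_errs, dx_errs = unrolled_gd_errors(A, b, u=1.0, num_iters=100)
print([round(e, 3) for e in x_errs[:6]])
print([round(e, 3) for e in dx_errs[:6]])
```

On this data the iterate error contracts by a factor of $2/3$ at every step. The derivative error behaves differently: it rises from about 0.39 to a peak of about 0.61 at $k = 3$ and only then decays, which is the transient described above. Whether such growth appears depends on the spectrum of $A^\top A + uI$ and on the data. This instance was chosen so that it does.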

Paper Structure (33 sections, 6 theorems, 34 equations, 14 figures, 2 tables, 1 algorithm)

Key Result

Theorem 2

Suppose $\mathcal{A}$ satisfies Assumption ass:FixMap:Contraction. Then there exists a $C^1$-smooth map $\bm x^{\star}\colon \mathcal{U} \to \mathcal{X}$ such that, for all $\bm u \in \mathcal{U}$, $\bm x^{\star}(\bm u) = \mathcal{A}(\bm x^{\star}(\bm u), \bm u)$ is the unique fixed point of $\mathcal{A}(\cdot, \bm u)$ … where the maps $B\colon U \to \mathcal{L}(\mathcal{X}, \mathcal{X})$ and $C\colon U \to$ …
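The statement above is cut off before the Jacobian formula. Suppose $B(\bm u)$ and $C(\bm u)$ denote the partial derivatives of $\mathcal{A}$ with respect to $\bm x$ and $\bm u$ at the fixed point. This reading is our assumption and is not confirmed by the text. Under it, the standard implicit-function argument behind such a result can be sketched as follows. The operator $I - B(\bm u)$ is invertible because the contraction property gives $\Vert B(\bm u) \Vert < 1$, so its Neumann series converges.

```latex
\begin{aligned}
\bm x^{\star}(\bm u) &= \mathcal{A}(\bm x^{\star}(\bm u), \bm u) \\
\Longrightarrow\quad D\bm x^{\star}(\bm u)
  &= \underbrace{D_{\bm x}\mathcal{A}(\bm x^{\star}(\bm u), \bm u)}_{=: B(\bm u)}\, D\bm x^{\star}(\bm u)
   + \underbrace{D_{\bm u}\mathcal{A}(\bm x^{\star}(\bm u), \bm u)}_{=: C(\bm u)} \\
\Longrightarrow\quad D\bm x^{\star}(\bm u) &= \bigl(I - B(\bm u)\bigr)^{-1} C(\bm u).
\end{aligned}
```

Unrolling instead differentiates the iteration $\bm x^{(k+1)} = \mathcal{A}(\bm x^{(k)}, \bm u)$. This gives the recursion $D\bm x^{(k+1)} = D_{\bm x}\mathcal{A}(\bm x^{(k)}, \bm u)\, D\bm x^{(k)} + D_{\bm u}\mathcal{A}(\bm x^{(k)}, \bm u)$, whose limit satisfies the same linear equation. As an example, take gradient descent on the objective of Figure 1. There $B(u) = I - \alpha H(u)$ with $H(u) = A^\top A + uI$ and $C(u) = -\alpha\, \bm x^{\star}(u)$, so the formula gives $D\bm x^{\star}(u) = -H(u)^{-1} \bm x^{\star}(u)$.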

Figures (14)

  • Figure 1: Iterate $\bm x^{(k)}(\bm u)$ vs. derivative $D\bm x^{(k)}(\bm u)$ error plot for gradient descent applied to $f(\bm x, u) \coloneqq \Vert A\bm x - \bm b \Vert^2/2 + u\Vert \bm x \Vert^2/2$. Unlike $\bm x^{(k)}(\bm u)$, $D\bm x^{(k)}(\bm u)$ initially drifts away from its limit before eventually coming back to it. We provide a non-asymptotic understanding of this transient behavior, called the curse of unrolling, and study simple ways to mitigate it.
  • Figure 2: Error evolution of $e^{(k)}(\bm u)$, $\dot e^{(k)}(\bm u)$, and $\bar{e}^{(k)}(\bm u)$ generated by gradient descent applied to $f(\bm x, u) \coloneqq \Vert A\bm x - \bm b \Vert^2/2 + u\Vert \bm x \Vert^2/2$. The dashed lines denote the bounds given in (eq:conv:rate) and (eq:AD:conv:rate). The vertical lines denote $\dot k$ and $\bar{k}$, defined in Sections ssec:curse:forward and ssec:curse:reverse, respectively.
  • Figure 3: Late-start / truncation behavior for $N = 2$, $\alpha = 2 / (L + m)$, and $\rho \approx$ …
  • Figure 4: Late-start / truncation behavior for $N = 5$, $\alpha = 2 / (L + m)$, and $\rho \approx$ …
  • Figure 5: Late-start / truncation behavior for $N = 10$, $\alpha = 2 / (L + m)$, and $\rho \approx$ …
  • ...and 9 more figures
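The two remedies named in the TL;DR can be compared on a toy problem of the same kind. The sketch uses NumPy, the hypothetical function `derivative_errors`, and illustrative data. It rests on two assumptions of ours. First, a late start or truncation at $N$ means running the first $N$ gradient steps without differentiating them and then starting the derivative recursion at $D\bm x^{(N)} = 0$; this is our reading of $N$ in Figures 3 to 5. Second, warm-starting is modelled by initializing $\bm x^{(0)}$ at the solution for a nearby hyperparameter value, as an outer bilevel loop would, while the derivative still starts from zero.

```python
import numpy as np


def derivative_errors(A, b, u, num_iters, start=0, x0=None):
    """Forward-mode unrolled gradient descent on
    f(x, u) = ||A x - b||^2 / 2 + u ||x||^2 / 2 (hypothetical helper).

    The first `start` steps are run without differentiation; the derivative
    recursion then begins from D x^(start) = 0. Returns the derivative errors
    ||D x^(k)(u) - D x*(u)|| for k = start, ..., num_iters.
    """
    n = A.shape[1]
    c = A.T @ b
    H = A.T @ A + u * np.eye(n)
    eigs = np.linalg.eigvalsh(H)               # ascending eigenvalues
    alpha = 2.0 / (eigs[-1] + eigs[0])         # 2 / (L + m)
    x_star = np.linalg.solve(H, c)
    dx_star = -np.linalg.solve(H, x_star)      # exact Jacobian D x*(u)

    x = np.zeros(n) if x0 is None else np.asarray(x0, dtype=float).copy()
    for _ in range(start):                     # truncated (undifferentiated) steps
        x = x - alpha * (H @ x - c)

    dx = np.zeros(n)
    errs = [float(np.linalg.norm(dx - dx_star))]
    for _ in range(start, num_iters):
        dx = dx - alpha * (H @ dx + x)         # derivative of the step, old x
        x = x - alpha * (H @ x - c)
        errs.append(float(np.linalg.norm(dx - dx_star)))
    return errs


# Illustrative data (assumption).
A = np.diag([1.0, 3.0])
b = np.array([1.0, 10.0])
u = 1.0

full = derivative_errors(A, b, u, num_iters=100)            # plain unrolling
late = derivative_errors(A, b, u, num_iters=100, start=5)   # truncation, N = 5
# Warm start: begin at the solution for a nearby outer iterate u = 1.1.
x_prev = np.linalg.solve(A.T @ A + 1.1 * np.eye(2), A.T @ b)
warm = derivative_errors(A, b, u, num_iters=100, x0=x_prev)

for name, errs in [("full", full), ("late start", late), ("warm start", warm)]:
    print(f"{name:>10}: initial {errs[0]:.3f}, peak {max(errs):.3f}")
```

On this data, the plain unrolled derivative error peaks at roughly 1.55 times its initial value. With the late start ($N = 5$) or the warm start, it never exceeds its initial value. In reverse mode, the truncated variant would only need to store the iterates from $N$ onward, which is consistent with the memory savings mentioned in the abstract.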

Theorems & Definitions (22)

  • Remark 1
  • Theorem 2
  • Remark 3
  • Lemma 4
  • Proof
  • Remark 5
  • Lemma 6
  • Proof
  • Theorem 7: Convergence of Derivative Iterates
  • Proof
  • ...and 12 more