Reversible Deep Equilibrium Models

Sam McCallum, Kamran Arora, James Foster

TL;DR

This work tackles instability and gradient-approximation challenges in Deep Equilibrium Models (DEQs) by introducing Reversible Deep Equilibrium Models (RevDEQs). RevDEQs use an algebraically reversible fixed-point solver to compute exact gradients with constant memory, reducing both the number of function evaluations and the need for regularisation. Empirical results on language modelling (Wikitext-103) and image classification (CIFAR-10) show that RevDEQs outperform existing implicit models and competitive explicit baselines, sometimes matching or exceeding strong ResNet/Transformer performance with far fewer function evaluations. The approach highlights the practical potential of reversible implicit depth and suggests broad applicability and improved GPU efficiency for large-scale tasks.
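To make "exact gradients with constant memory" concrete, here is a minimal scalar sketch in Python. The update rule is an assumption, since the reversible iteration itself is not reproduced on this page: it takes the coupled relaxed step $y_{n+1} = (1-\beta)y_n + \beta f(z_n)$, $z_{n+1} = (1-\beta)z_n + \beta f(y_{n+1})$ with a hypothetical layer $f(v) = w\tanh(v) + x$ and $\beta \neq 1$. The forward pass keeps only the current pair $(y, z)$; the backward pass rebuilds each earlier pair by inverting the step algebraically and accumulates $\partial L/\partial w$ along the way, so no intermediate states are stored.

```python
import math


def f(v, w, x):
    # Hypothetical scalar layer with learnable parameter w: w * tanh(v) + x.
    return w * math.tanh(v) + x


def df_dv(v, w):
    # Derivative of f with respect to its input v.
    return w * (1.0 - math.tanh(v) ** 2)


def forward_step(y, z, w, x, beta):
    # Assumed coupled relaxed update; beta != 1 keeps it invertible.
    y1 = (1 - beta) * y + beta * f(z, w, x)
    z1 = (1 - beta) * z + beta * f(y1, w, x)
    return y1, z1


def reverse_step(y1, z1, w, x, beta):
    # Algebraic inverse of forward_step: undo the z update, then the y update.
    z = (z1 - beta * f(y1, w, x)) / (1 - beta)
    y = (y1 - beta * f(z, w, x)) / (1 - beta)
    return y, z


def loss_and_grad(w, x, target, beta=0.5, steps=10):
    # Forward pass: only the current pair (y, z) is kept in memory.
    y, z = 0.0, 0.0
    for _ in range(steps):
        y, z = forward_step(y, z, w, x, beta)
    loss = 0.5 * (y - target) ** 2

    # Backward pass: gy, gz are dL/dy and dL/dz for the current state.
    gy, gz, gw = y - target, 0.0, 0.0
    for _ in range(steps):
        y1, z1 = y, z
        y, z = reverse_step(y1, z1, w, x, beta)  # rebuild this step's inputs
        # Through z1 = (1 - beta) * z + beta * f(y1): y1 also feeds z1.
        gy_total = gy + gz * beta * df_dv(y1, w)
        gw += gz * beta * math.tanh(y1)
        gz_prev = gz * (1 - beta)
        # Through y1 = (1 - beta) * y + beta * f(z).
        gy_prev = gy_total * (1 - beta)
        gz_prev += gy_total * beta * df_dv(z, w)
        gw += gy_total * beta * math.tanh(z)
        gy, gz = gy_prev, gz_prev
    return loss, gw, (y, z)  # (y, z) should be back at the initial state (0, 0)


w, x, target = 0.5, 0.3, 1.0
loss, grad_w, (y0, z0) = loss_and_grad(w, x, target)
h = 1e-6
fd = (loss_and_grad(w + h, x, target)[0] - loss_and_grad(w - h, x, target)[0]) / (2 * h)
print(grad_w, fd)  # the two estimates should agree to several decimal places
print(y0, z0)      # reconstructed initial state, close to 0.0
```

The gradient is exact for the computed forward pass (up to floating-point error), which the central finite difference confirms. Inverting a contractive step expands rounding errors, so in this plain float64 sketch the reconstruction would drift if the number of steps were large; the demo keeps it at 10.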

Abstract

Deep Equilibrium Models (DEQs) are an interesting class of implicit model in which the model output is implicitly defined as the fixed point of a learned function. These models have been shown to outperform explicit (fixed-depth) models in large-scale tasks by trading many deep layers for a single layer that is iterated many times. However, gradient calculation through DEQs is approximate. This often leads to unstable training dynamics that must be fixed with regularisation or many function evaluations. Here, we introduce Reversible Deep Equilibrium Models (RevDEQs), which allow exact gradient calculation, require no regularisation and use far fewer function evaluations than DEQs. We show that RevDEQs significantly improve performance on language modelling and image classification tasks against comparable implicit and explicit models.
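The idea of an output "implicitly defined as the fixed point of a learned function" can be illustrated with the simplest possible solver. The Python sketch below is a naive DEQ forward pass under stated assumptions: the layer is a hypothetical contraction $f(z, x) = w\tanh(z) + x$ applied elementwise with $|w| < 1$, and the solver is plain fixed-point iteration (practical DEQs usually rely on accelerated solvers such as Broyden's method or Anderson acceleration). Because $f$ is a contraction, the Banach fixed point theorem guarantees a unique fixed point and convergence from any starting point.

```python
import math


def layer(z, x, w=0.5):
    # Hypothetical "single layer": elementwise w * tanh(z) + x.
    # With |w| < 1 it is a contraction in the max norm.
    return [w * math.tanh(zi) + xi for zi, xi in zip(z, x)]


def deq_forward(x, tol=1e-10, max_iter=500):
    """Naive DEQ forward pass: repeat z <- layer(z, x) until z stops changing."""
    z = [0.0] * len(x)
    for n in range(1, max_iter + 1):
        z_next = layer(z, x)
        step = max(abs(a - b) for a, b in zip(z_next, z))
        z = z_next
        if step < tol:
            return z, n  # approximate fixed point and number of layer calls
    return z, max_iter


x = [0.3, -0.2]
z_star, n_evals = deq_forward(x)
residual = max(abs(a - b) for a, b in zip(z_star, layer(z_star, x)))
print(n_evals, residual)
```

Here the effective "depth" of the model is `n_evals`, the number of times the single layer is applied; it is chosen by the solver's stopping rule rather than fixed in advance.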

Paper Structure

This paper contains 39 sections, 6 theorems, 31 equations, 2 figures, 4 tables, 1 algorithm.

Key Result

Theorem 3.1

Let $\{\mathbf{y}_n, \mathbf{z}_n\}_{n \geq 0}$ be the sequence obtained by the reversible iteration in equation (eq:reversible-fixed-point), and let $f: \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a contractive function with Lipschitz constant $0 < k < 1$. Then, for all $0<\beta<2/(k+1)$, the pair …
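Theorem 3.1 can be sanity-checked numerically. Equation eq:reversible-fixed-point is not reproduced on this page, so the Python sketch below assumes the coupled relaxed form $\mathbf{y}_{n+1} = (1-\beta)\mathbf{y}_n + \beta f(\mathbf{z}_n)$, $\mathbf{z}_{n+1} = (1-\beta)\mathbf{z}_n + \beta f(\mathbf{y}_{n+1})$, and uses a hypothetical contraction $f$ with $k = 0.5$ in the max norm. For this assumed form, writing $m_n = \max(\|\mathbf{y}_n - \mathbf{y}^*\|, \|\mathbf{z}_n - \mathbf{y}^*\|)$ with $\mathbf{y}^* = f(\mathbf{y}^*)$, the triangle inequality and the Lipschitz bound on $f$ give $m_{n+1} \le q\, m_n$ with $q = |1-\beta| + \beta k$ whenever $q < 1$, and $q < 1$ holds exactly when $0 < \beta < 2/(k+1)$, the range in the theorem. The code checks that both sequences reach $\mathbf{y}^*$ for several $\beta$ in that range.

```python
import math

K = 0.5          # Lipschitz constant of the hypothetical layer below (max norm)
X = [0.3, -0.2]  # hypothetical fixed input


def f(v):
    # Hypothetical contraction: K * tanh(v) + X, elementwise.
    return [K * math.tanh(vi) + xi for vi, xi in zip(v, X)]


def rate(beta, k=K):
    # Per-step contraction bound q = |1 - beta| + beta * k for the assumed update.
    return abs(1 - beta) + beta * k


def reversible_iterate(beta, steps):
    # Assumed coupled relaxed update; the z update uses the freshly updated y.
    y = [0.0] * len(X)
    z = [0.0] * len(X)
    for _ in range(steps):
        y = [(1 - beta) * yi + beta * fi for yi, fi in zip(y, f(z))]
        z = [(1 - beta) * zi + beta * fi for zi, fi in zip(z, f(y))]
    return y, z


def dist(a, b):
    # Max-norm distance between two vectors.
    return max(abs(ai - bi) for ai, bi in zip(a, b))


# Reference fixed point y* = f(y*) via plain iteration (contraction rate K).
y_star = [0.0] * len(X)
for _ in range(200):
    y_star = f(y_star)

beta_max = 2 / (K + 1)
for beta in (0.5, 0.9, 1.3):
    assert 0 < beta < beta_max and rate(beta) < 1
    y, z = reversible_iterate(beta, steps=1000)
    print(beta, rate(beta), dist(y, y_star), dist(z, y_star))
```

The bound $q$ can be loose (for $\beta = 1.3$ it is 0.95), so the step count here is deliberately generous.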

Figures (2)

  • Figure 1: Example of the forward and backward passes in RevDEQ. Starting at $\{\mathbf{y}_0, \mathbf{z}_0\}$ we iterate the forward step until we reach (an approximation to) the fixed point $\{\mathbf{y}^*, \mathbf{z}^*\}$. On the backward pass we reverse each forward step until we return to the initial condition.
  • Figure 2: A single scale of the multi-scale implicit ResNet architecture.
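The forward/backward picture in Figure 1 can be sketched in a few lines of Python. The update rule is again an assumption (a coupled relaxed step with a hypothetical elementwise contraction `f` and `BETA` ≠ 1), not code from the paper. The forward loop stops once successive iterates move by less than a tolerance, counting its steps; the backward loop then inverts that many steps, starting only from the approximate fixed point, and lands back on $\{\mathbf{y}_0, \mathbf{z}_0\}$.

```python
import math

BETA = 0.5       # hypothetical relaxation parameter; must not equal 1
X = [0.3, -0.2]  # hypothetical fixed input


def f(v):
    # Hypothetical contraction: 0.5 * tanh(v) + X, elementwise.
    return [0.5 * math.tanh(vi) + xi for vi, xi in zip(v, X)]


def forward_step(y, z):
    # Assumed coupled relaxed update; the z update uses the new y.
    y1 = [(1 - BETA) * yi + BETA * fi for yi, fi in zip(y, f(z))]
    z1 = [(1 - BETA) * zi + BETA * fi for zi, fi in zip(z, f(y1))]
    return y1, z1


def reverse_step(y1, z1):
    # Algebraic inverse of forward_step: undo the z update, then the y update.
    z = [(zi - BETA * fi) / (1 - BETA) for zi, fi in zip(z1, f(y1))]
    y = [(yi - BETA * fi) / (1 - BETA) for yi, fi in zip(y1, f(z))]
    return y, z


def dist(a, b):
    # Max-norm distance between two vectors.
    return max(abs(ai - bi) for ai, bi in zip(a, b))


y0, z0 = [0.1, -0.4], [0.2, 0.0]

# Forward pass: iterate towards {y*, z*}, keeping only the current state.
y, z, n_steps = y0, z0, 0
while n_steps < 15:
    y_new, z_new = forward_step(y, z)
    n_steps += 1
    moved = max(dist(y_new, y), dist(z_new, z))
    y, z = y_new, z_new
    if moved < 1e-3:
        break

# Backward pass: reverse each forward step until we return to {y0, z0}.
y_rec, z_rec = y, z
for _ in range(n_steps):
    y_rec, z_rec = reverse_step(y_rec, z_rec)

print(n_steps, dist(y_rec, y0), dist(z_rec, z0))
```

Only the final pair is carried from the forward pass into the backward pass. Because inverting a contraction expands errors, exact algebraic reversal in float64 drifts as the number of steps grows; the loose tolerance and the cap of 15 steps keep this demo well within float64 accuracy.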

Theorems & Definitions (11)

  • Theorem 3.1
  • proof
  • Theorem : Banach fixed point theorem
  • Theorem A.1
  • proof
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Corollary A.1
  • ...and 1 more