Deep Equilibrium Models

Shaojie Bai, J. Zico Kolter, Vladlen Koltun

TL;DR

The paper addresses the memory bottleneck in deep sequence models by reframing depth as an equilibrium problem. It introduces the deep equilibrium model (DEQ), which solves directly for a sequence-level fixed point z* satisfying z* = f_theta(z*; x) via root-finding. It then backpropagates through this fixed point with implicit differentiation, so training memory stays constant regardless of the effective depth. The authors instantiate DEQ with TrellisNet and weight-tied transformers, obtaining competitive perplexities on PTB and WikiText-103 while reducing memory usage by up to 88% in their experiments. They also show that stacking DEQs adds no representational power, and they discuss practical aspects such as initialization, convergence, and runtime trade-offs. Overall, DEQ is a memory-efficient framework for sequence modeling that can be applied to existing architectures and to large-scale tasks.
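The forward/backward split described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation (their code is at the repository linked in the abstract). It assumes a hypothetical single-vector layer f_theta(z; x) = tanh(W z + U x + b) and a squared-error loss, finds z* by plain fixed-point iteration (the paper uses Broyden's method), and computes dl/dW from z* alone by implicit differentiation, without storing any intermediate iterates.

```python
import numpy as np

# Hypothetical toy layer f_theta(z; x) = tanh(W z + U x + b) with theta = (W, U, b).
# W is rescaled to spectral norm 0.5; since tanh is 1-Lipschitz, f is then a
# contraction in z and has a unique fixed point z*.
rng = np.random.default_rng(0)
d, p = 4, 3
M = rng.standard_normal((d, d))
W = 0.5 * M / np.linalg.norm(M, 2)
U = rng.standard_normal((d, p))
b = rng.standard_normal(d)
x = rng.standard_normal(p)


def f(z, x, W):
    return np.tanh(W @ z + U @ x + b)


def solve_fixed_point(x, W, tol=1e-13, max_iter=1000):
    """Forward pass: find z* with z* = f(z*; x). Plain fixed-point iteration
    keeps the sketch short. Intermediate iterates are discarded, so memory
    does not grow with the number of iterations ("depth")."""
    z = np.zeros(d)
    for _ in range(max_iter):
        z_next = f(z, x, W)
        if np.linalg.norm(z_next - z) < tol:
            return z_next
        z = z_next
    return z


def loss_and_grad_W(x, W, y):
    """Loss l = 0.5 * ||z* - y||^2 and dl/dW by implicit differentiation:
    dl/dW = dl/dz* (I - J_f)^{-1} df/dW, evaluated using z* alone."""
    z = solve_fixed_point(x, W)
    loss = 0.5 * np.sum((z - y) ** 2)
    dl_dz = z - y
    s = 1.0 - z ** 2                 # tanh' of the pre-activation, since z* = tanh(.)
    J_f = s[:, None] * W             # df/dz evaluated at z*
    # Row vector u^T = dl/dz* (I - J_f)^{-1}, i.e. solve (I - J_f)^T u = dl/dz*.
    u = np.linalg.solve((np.eye(d) - J_f).T, dl_dz)
    # df_i/dW_jk = s_i * [i == j] * z_k, hence dl/dW = outer(u * s, z*).
    grad_W = np.outer(u * s, z)
    return loss, grad_W


y = rng.standard_normal(d)
loss, grad_W = loss_and_grad_W(x, W, y)
```

The backward pass costs one d-by-d linear solve regardless of how many forward iterations were needed, which is where the constant-memory property comes from.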

Abstract

We present a new approach to modeling sequential data: the deep equilibrium model (DEQ). Motivated by an observation that the hidden layers of many existing deep sequence models converge towards some fixed point, we propose the DEQ approach that directly finds these equilibrium points via root-finding. Such a method is equivalent to running an infinite depth (weight-tied) feedforward network, but has the notable advantage that we can analytically backpropagate through the equilibrium point using implicit differentiation. Using this approach, training and prediction in these networks require only constant memory, regardless of the effective "depth" of the network. We demonstrate how DEQs can be applied to two state-of-the-art deep sequence models: self-attention transformers and trellis networks. On large-scale language modeling tasks, such as the WikiText-103 benchmark, we show that DEQs 1) often improve performance over these state-of-the-art models (for similar parameter counts); 2) have similar computational requirements to existing models; and 3) vastly reduce memory consumption (often the bottleneck for training large sequence models), demonstrating an up-to 88% memory reduction in our experiments. The code is available at https://github.com/locuslab/deq .
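The claimed equivalence between root-finding and an infinitely deep weight-tied network can be checked numerically on a toy problem. The sketch below is an illustration under assumptions (a hypothetical layer tanh(W z + c) with W scaled to spectral norm 0.8, and Newton's method standing in for the paper's root-finder): it compares the equilibrium found by root-finding with the output of the same layer unrolled to increasing depths.

```python
import numpy as np

# Hypothetical weight-tied layer z <- tanh(W z + c), where c stands in for the
# injected input term. W is rescaled to spectral norm 0.8, which makes the layer
# a contraction (tanh is 1-Lipschitz), so a unique equilibrium exists.
rng = np.random.default_rng(1)
d = 5
M = rng.standard_normal((d, d))
W = 0.8 * M / np.linalg.norm(M, 2)
c = rng.standard_normal(d)


def layer(z):
    return np.tanh(W @ z + c)


def unrolled(depth):
    """Weight-tied deep network: the same layer applied `depth` times,
    with the input injected at every layer, starting from z = 0."""
    z = np.zeros(d)
    for _ in range(depth):
        z = layer(z)
    return z


def newton_root(tol=1e-12, max_iter=50):
    """Root-finding view: solve g(z) = layer(z) - z = 0 by Newton's method."""
    z = np.zeros(d)
    for _ in range(max_iter):
        t = np.tanh(W @ z + c)
        gz = t - z
        if np.linalg.norm(gz) < tol:
            break
        J_g = (1.0 - t ** 2)[:, None] * W - np.eye(d)   # Jacobian of g at z
        z = z - np.linalg.solve(J_g, gz)
    return z


z_star = newton_root()
depths = (1, 5, 20, 80)
gaps = [np.linalg.norm(unrolled(L) - z_star) for L in depths]
```

Because the layer is a contraction, the gap between the unrolled network and z_star shrinks geometrically with depth, in the spirit of the activation convergence shown in Figure 4. The root-finder reaches the same point in a handful of steps without materializing the deep stack.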

Paper Structure

This paper contains 42 sections, 5 theorems, 27 equations, 4 figures, 4 tables.

Key Result

Theorem 1

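(Gradient of the Equilibrium Model) Let $\mathbf{z}_{1:T}^\star \in \mathbb{R}^{T \times d}$ be an equilibrium hidden sequence with length $T$ and dimensionality $d$, and $\mathbf{y}_{1:T} \in \mathbb{R}^{T \times q}$ the ground-truth (target) sequence. Let $h: \mathbb{R}^d \rightarrow \mathbb{R}^q$ be any differentiable function and let $\mathcal{L}: \mathbb{R}^q \times \mathbb{R}^q \rightarrow \mathbb{R}$ be a loss function (where $h$ and $\mathcal{L}$ are applied in a vectorized manner) that computes

$$\ell = \mathcal{L}\big(h(\mathbf{z}_{1:T}^\star), \mathbf{y}_{1:T}\big) = \mathcal{L}\big(h(\mathrm{RootFind}(g_\theta; \mathbf{x}_{1:T})), \mathbf{y}_{1:T}\big),$$

where $g_\theta(\mathbf{z}_{1:T}; \mathbf{x}_{1:T}) = f_\theta(\mathbf{z}_{1:T}; \mathbf{x}_{1:T}) - \mathbf{z}_{1:T}$. Then the loss gradient w.r.t. $(\cdot)$ (for instance, $\theta$ or $\mathbf{x}_{1:T}$) is

$$\frac{\partial \ell}{\partial (\cdot)} = -\frac{\partial \ell}{\partial \mathbf{z}_{1:T}^\star}\left(J_{g_\theta}^{-1}\big|_{\mathbf{z}_{1:T}^\star}\right)\frac{\partial f_\theta(\mathbf{z}_{1:T}^\star; \mathbf{x}_{1:T})}{\partial (\cdot)} = -\frac{\partial \ell}{\partial h}\frac{\partial h}{\partial \mathbf{z}_{1:T}^\star}\left(J_{g_\theta}^{-1}\big|_{\mathbf{z}_{1:T}^\star}\right)\frac{\partial f_\theta(\mathbf{z}_{1:T}^\star; \mathbf{x}_{1:T})}{\partial (\cdot)},$$

where $J_{g_\theta}^{-1}\big|_{\mathbf{z}}$ is the inverse Jacobian of $g_\theta$ evaluated at $\mathbf{z}$.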
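The gradient formula of Theorem 1 can be sanity-checked on a small instance. Writing $g_\theta = f_\theta - \mathbf{z}$, its Jacobian is $J_{g_\theta} = J_{f_\theta} - I$, so $-J_{g_\theta}^{-1} = (I - J_{f_\theta})^{-1}$. The sketch below takes $(\cdot) = \mathbf{x}$ and assumes, purely for illustration, a single time step ($T = 1$), $f_\theta(\mathbf{z}; \mathbf{x}) = \tanh(W\mathbf{z} + U\mathbf{x} + \mathbf{b})$, a linear head $h(\mathbf{z}) = V\mathbf{z}$, and squared-error $\mathcal{L}$. Instead of inverting the Jacobian, it solves for the row vector $\mathbf{w}^\top = \frac{\partial \ell}{\partial \mathbf{z}^\star}(I - J_{f_\theta})^{-1}$ by iterating $\mathbf{w} \leftarrow J_{f_\theta}^\top \mathbf{w} + \frac{\partial \ell}{\partial \mathbf{z}^\star}$. This is one simple choice of linear solver, not necessarily the paper's.

```python
import numpy as np

# Toy instance of Theorem 1 (all shapes and names are illustrative assumptions):
# f(z; x) = tanh(W z + U x + b),  h(z) = V z,  L(h, y) = 0.5 * ||h - y||^2.
rng = np.random.default_rng(2)
d, p, q = 6, 3, 2
M = rng.standard_normal((d, d))
W = 0.7 * M / np.linalg.norm(M, 2)   # spectral norm 0.7 -> f is a contraction in z
U = rng.standard_normal((d, p))
b = rng.standard_normal(d)
V = rng.standard_normal((q, d))
y = rng.standard_normal(q)


def equilibrium(x, iters=300):
    """Forward pass by fixed-point iteration (converges since f is a contraction)."""
    z = np.zeros(d)
    for _ in range(iters):
        z = np.tanh(W @ z + U @ x + b)
    return z


def loss(x):
    z = equilibrium(x)
    return 0.5 * np.sum((V @ z - y) ** 2)


def grad_x(x, iters=300):
    """dl/dx = (dl/dh)(dh/dz*) (I - J_f)^{-1} df/dx, per Theorem 1.
    Only products with J_f^T are used; no inverse is formed."""
    z = equilibrium(x)
    s = 1.0 - z ** 2                  # tanh' at the equilibrium
    dl_dz = V.T @ (V @ z - y)         # (dl/dh)(dh/dz*) as a column vector
    w = np.zeros(d)
    for _ in range(iters):
        w = W.T @ (s * w) + dl_dz     # J_f^T w, with J_f = diag(s) W
    return U.T @ (s * w)              # w^T df/dx, with df/dx = diag(s) U


x = rng.standard_normal(p)
g = grad_x(x)
```

The backward iteration converges for the same reason the forward one does: $\|J_{f_\theta}^\top\| \le 0.7$ here, so the linear map is also a contraction.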

Figures (4)

  • Figure 1: Comparison of the DEQ with conventional weight-tied deep networks.
  • Figure 2: Left: the number of Broyden iterations in the forward and backward passes gradually grows over training epochs. Right: DEQ-Transformer finds the equilibrium in a stable and efficient manner (whereas the deep transformer could oscillate around the fixed point, even when one exists).
  • Figure 3: DEQ can be accelerated by using a higher tolerance $\varepsilon$ (left) or a lower Broyden iteration limit (right). In general, poor estimates of the equilibrium can hurt DEQ performance.
  • Figure 4: The convergence of intermediate activations in TrellisNet (with kernel size 2) and weight-tied transformers on different sequence lengths.
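Figures 2 and 3 refer to Broyden's method, its tolerance $\varepsilon$, and an iteration limit. A minimal sketch of "good" Broyden root-finding with both stopping criteria follows. The dense inverse-Jacobian estimate B, the -I initialization, the function name `broyden`, and the toy problem g(z) = tanh(W z + c) - z are assumptions made for illustration; a memory-conscious implementation would keep the rank-one updates instead of a dense matrix.

```python
import numpy as np


def broyden(g, z0, eps=1e-8, max_iter=30):
    """Find a root of g with Broyden's ("good") method.

    B is a dense estimate of the inverse Jacobian of g, refreshed each step by
    a rank-one Sherman-Morrison update. Iteration stops once ||g(z)|| < eps or
    after max_iter steps. Returns (z, final residual norm, steps taken)."""
    z = np.array(z0, dtype=float)
    gz = g(z)
    B = -np.eye(z.size)   # for g = f - z, J_g = J_f - I, so -I is a natural start
    steps = 0
    while np.linalg.norm(gz) >= eps and steps < max_iter:
        dz = -B @ gz                      # quasi-Newton step
        z_new = z + dz
        g_new = g(z_new)
        dg = g_new - gz
        Bdg = B @ dg
        denom = dz @ Bdg
        if abs(denom) < 1e-30:            # degenerate secant pair: stop early
            break
        B = B + np.outer(dz - Bdg, dz @ B) / denom
        z, gz = z_new, g_new
        steps += 1
    return z, np.linalg.norm(gz), steps


# Hypothetical toy equilibrium problem g(z) = tanh(W z + c) - z.
rng = np.random.default_rng(3)
d = 5
M = rng.standard_normal((d, d))
W = 0.5 * M / np.linalg.norm(M, 2)
c = rng.standard_normal(d)


def g(z):
    return np.tanh(W @ z + c) - z


z0 = np.zeros(d)
results = {eps: broyden(g, z0, eps=eps, max_iter=100) for eps in (1e-3, 1e-6, 1e-10)}
capped = broyden(g, z0, eps=1e-10, max_iter=2)   # tight iteration limit
```

A looser `eps` or a smaller `max_iter` returns after fewer steps but with a larger residual, i.e. a poorer equilibrium estimate. This mirrors the speed/accuracy trade-off described in the Figure 3 caption.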

Theorems & Definitions (5)

  • Theorem 1
  • Theorem 2
  • Theorem 1
  • Theorem 2
  • Theorem 3