Deep Equilibrium Models
Shaojie Bai, J. Zico Kolter, Vladlen Koltun
TL;DR
The paper addresses the memory bottleneck in deep sequence models by reframing depth as an equilibrium problem. It introduces the deep equilibrium model (DEQ), which solves directly for a sequence-level fixed point z* satisfying z* = f_theta(z*; x) via root-finding, then backpropagates through that fixed point with implicit differentiation, so training memory stays constant regardless of effective depth. The authors instantiate DEQ with TrellisNet and weight-tied transformers, reporting competitive perplexities on PTB and WikiText-103 while cutting memory usage by up to 88% in their experiments. They also show that stacking multiple DEQs adds no representational power over a single DEQ, and they discuss practical aspects such as initialization, convergence, and runtime trade-offs. Overall, DEQ is a memory-efficient framework for sequence modeling that can be applied to existing architectures and scaled to large tasks.
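The two technical claims above can be written out compactly. The sketch below uses the summary's notation; the residual g_theta and the stacked map F are notation introduced here for illustration. The first part is the implicit-function-theorem gradient: it depends only on z* and the Jacobian at z*, not on the solver's intermediate iterates, which is why memory does not grow with depth. The second part is one way to see the stacking result (a sketch, not a quotation of the paper's proof): two chained equilibrium layers can be rewritten as a single equilibrium layer over the concatenated state.

```latex
% Root-finding residual (notation introduced for this sketch)
g_\theta(z; x) = f_\theta(z; x) - z, \qquad g_\theta(z^\star; x) = 0

% Implicit differentiation at the fixed point (no stored iterates needed)
\frac{\partial \ell}{\partial \theta}
  = -\frac{\partial \ell}{\partial z^\star}
    \left( J_{g_\theta} \big|_{z^\star} \right)^{-1}
    \frac{\partial f_\theta(z^\star; x)}{\partial \theta},
\qquad
J_{g_\theta} = \frac{\partial f_\theta}{\partial z} - I

% Stacking: z_1^\star = f^{(1)}(z_1^\star; x) followed by z_2^\star = f^{(2)}(z_2^\star; z_1^\star)
% is a single equilibrium over the concatenated state z = (z_1, z_2):
\begin{bmatrix} z_1^\star \\ z_2^\star \end{bmatrix}
  = F\!\left( \begin{bmatrix} z_1^\star \\ z_2^\star \end{bmatrix}; x \right),
\qquad
F\!\left( \begin{bmatrix} z_1 \\ z_2 \end{bmatrix}; x \right)
  = \begin{bmatrix} f^{(1)}(z_1; x) \\ f^{(2)}(z_2; z_1) \end{bmatrix}
```

Differentiating z* = f_theta(z*; x) gives (I - df/dz) dz*/dtheta = df/dtheta, and since -(J_g)^{-1} = (I - df/dz)^{-1}, this reproduces the first formula.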
Abstract
We present a new approach to modeling sequential data: the deep equilibrium model (DEQ). Motivated by an observation that the hidden layers of many existing deep sequence models converge towards some fixed point, we propose the DEQ approach that directly finds these equilibrium points via root-finding. Such a method is equivalent to running an infinite depth (weight-tied) feedforward network, but has the notable advantage that we can analytically backpropagate through the equilibrium point using implicit differentiation. Using this approach, training and prediction in these networks require only constant memory, regardless of the effective "depth" of the network. We demonstrate how DEQs can be applied to two state-of-the-art deep sequence models: self-attention transformers and trellis networks. On large-scale language modeling tasks, such as the WikiText-103 benchmark, we show that DEQs 1) often improve performance over these state-of-the-art models (for similar parameter counts); 2) have similar computational requirements to existing models; and 3) vastly reduce memory consumption (often the bottleneck for training large sequence models), demonstrating up to 88% memory reduction in our experiments. The code is available at https://github.com/locuslab/deq .
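The abstract's two central claims, that root-finding recovers the output of an infinitely deep weight-tied network and that gradients can be computed at the equilibrium alone, can be checked numerically on a toy problem. The sketch below assumes a hypothetical single-layer map f(z; x) = tanh(W z + U x + b), with W rescaled to spectral norm 0.4; that rescaling is an assumption made here so f is a contraction and the toy has a unique fixed point. It compares many unrolled applications of the same layer with Newton's method on g(z) = f(z; x) - z, then computes the gradient of a toy loss with respect to b by implicit differentiation and checks it against finite differences. The paper uses Broyden's quasi-Newton method on transformer and TrellisNet layers; Newton with an explicit Jacobian is swapped in only because the toy dimension is tiny. The names `layer`, `unrolled`, `newton_fixed_point`, and `implicit_grad_b` are illustrative, not the API of the released code.

```python
import numpy as np

# Hypothetical toy layer f(z; x) = tanh(W z + U x + b); not the paper's architecture.
def layer(z, x, W, U, b):
    return np.tanh(W @ z + U @ x + b)


def unrolled(x, W, U, b, depth):
    """Weight-tied deep network: apply the same layer `depth` times, starting at z = 0."""
    z = np.zeros(W.shape[0])
    for _ in range(depth):
        z = layer(z, x, W, U, b)
    return z


def newton_fixed_point(x, W, U, b, tol=1e-12, max_iter=50):
    """Root-find g(z) = f(z; x) - z with Newton's method (swapped in for Broyden)."""
    n = W.shape[0]
    z = np.zeros(n)
    for _ in range(max_iter):
        fz = layer(z, x, W, U, b)
        g = fz - z
        if np.linalg.norm(g) < tol:
            break
        J_f = (1.0 - fz**2)[:, None] * W  # df/dz = diag(1 - f^2) W
        z = z - np.linalg.solve(J_f - np.eye(n), g)  # Jacobian of g is J_f - I
    return z


def implicit_grad_b(z_star, x, W, U, b, dl_dz):
    """dl/db = -dl/dz* (J_g)^{-1} df/db, using only quantities at the fixed point."""
    f_star = layer(z_star, x, W, U, b)
    D = np.diag(1.0 - f_star**2)  # df/db = diag(1 - f^2)
    J_g = D @ W - np.eye(W.shape[0])
    u = np.linalg.solve(J_g.T, dl_dz)  # row vector u with u J_g = dl_dz
    return -(u @ D)


def loss(z):
    return 0.5 * z @ z  # toy loss; its gradient w.r.t. z is z itself


rng = np.random.default_rng(0)
n, p = 4, 3
W = rng.standard_normal((n, n))
W *= 0.4 / np.linalg.norm(W, 2)  # spectral norm 0.4, so f is a contraction in z
U = rng.standard_normal((n, p))
b = rng.standard_normal(n)
x = rng.standard_normal(p)

z_star = newton_fixed_point(x, W, U, b)
z_deep = unrolled(x, W, U, b, depth=100)
print("max |z* - z_deep|:", np.max(np.abs(z_star - z_deep)))

grad = implicit_grad_b(z_star, x, W, U, b, dl_dz=z_star)
eps = 1e-6
fd = np.array([
    (loss(newton_fixed_point(x, W, U, b + eps * e))
     - loss(newton_fixed_point(x, W, U, b - eps * e))) / (2 * eps)
    for e in np.eye(n)
])
print("max |implicit - finite diff|:", np.max(np.abs(grad - fd)))
```

Note that `implicit_grad_b` never touches the Newton iterates, which is the property behind the constant-memory claim; the finite-difference loop is only there to validate the toy and would be far too expensive at realistic scale.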
