Scaling Recurrent Neural Networks to a Billion Parameters with Zero-Order Optimization

Francois Chaubard, Mykel Kochenderfer

TL;DR

This work addresses the prohibitive memory cost of training long-context, billion-parameter RNNs with Backpropagation Through Time (BPTT) by introducing distributed Central-Difference Random-vector Gradient Estimation (CD-RGE), a zero-order optimization approach that keeps the model in an inference-like mode during training. By combining CD-RGE with FlashRNN and distributed computation, the method achieves comparable or superior convergence to BPTT across overfitting, transduction, and language modeling tasks, while dramatically reducing VRAM usage and enabling scalable training on accessible hardware. The finite-difference gradient estimates correspond to optimizing a smoothed surrogate loss $L_{\varepsilon}(\Theta) = \mathbb{E}_{p}[L(\Theta + \varepsilon p)]$, providing implicit regularization and improved generalization in practice. The results demonstrate that CD-RGE can match or exceed BPTT performance with orders-of-magnitude memory savings, and with distributed implementations can even surpass Transformer-like wall-clock efficiency, signaling a practical route to training large RNNs at scale with potential environmental and cost benefits.

Abstract

During inference, Recurrent Neural Networks (RNNs) have constant cost in both FLOPs and GPU memory as context length increases, because they compress all prior tokens into a fixed-size memory. In contrast, transformers scale linearly in FLOPs and, at best, linearly in memory during generation, since they must attend to all previous tokens explicitly. Despite this inference-time advantage, training large RNNs on long contexts remains impractical because standard optimization methods depend on Backpropagation Through Time (BPTT). BPTT requires retaining all intermediate activations from the forward pass, causing memory usage to scale linearly with both context length and model size. In this paper, we show that Zero-Order Optimization (ZOO) methods such as Random-vector Gradient Estimation (RGE) can successfully replace BPTT to train RNNs, with convergence rates that match BPTT or exceed it by up to 19-fold, while using orders of magnitude less memory and cost, as the model remains in inference mode throughout training. We further demonstrate that Central-Difference RGE (CD-RGE) corresponds to optimizing a smoothed surrogate loss, inherently regularizing training and improving generalization. Our method matches or outperforms BPTT across three settings: (1) overfitting, (2) transduction, and (3) language modeling. Across all tasks, with sufficient perturbations, our models generalize as well as or better than those trained with BPTT, often in fewer steps. Despite the need for more forward passes per step, we can beat BPTT's wall-clock time per step using recent advancements such as FlashRNN and distributed inference.
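To make the idea concrete, here is a minimal sketch of a central-difference random-vector gradient estimate, written in pure Python on a toy quadratic loss rather than an RNN. It assumes the standard form of the estimator, $\hat{g} = \frac{1}{n}\sum_{i=1}^{n} \frac{L(\Theta + \varepsilon p_i) - L(\Theta - \varepsilon p_i)}{2\varepsilon}\, p_i$ with Rademacher probes $p_i$; the paper's exact estimator, optimizer, and hyperparameters may differ. The names `cd_rge_gradient`, `quadratic_loss`, and `train` are hypothetical and chosen only for illustration.

```python
import random


def cd_rge_gradient(loss, theta, eps=1e-3, num_perturbations=64, rng=None):
    """Estimate the gradient of `loss` at `theta` using forward evaluations only.

    Assumed form (not taken verbatim from the paper): average over random
    probes p of [(L(theta + eps*p) - L(theta - eps*p)) / (2*eps)] * p.
    """
    rng = rng or random.Random(0)
    d = len(theta)
    grad = [0.0] * d
    for _ in range(num_perturbations):
        # Rademacher probe: each coordinate is +1 or -1 with equal probability.
        p = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        loss_plus = loss([t + eps * pi for t, pi in zip(theta, p)])
        loss_minus = loss([t - eps * pi for t, pi in zip(theta, p)])
        # Central-difference estimate of the directional derivative along p.
        scale = (loss_plus - loss_minus) / (2.0 * eps)
        for i in range(d):
            grad[i] += scale * p[i] / num_perturbations
    return grad


def quadratic_loss(theta):
    """Toy stand-in for a model loss; its minimum is at theta = (1, ..., 1)."""
    return sum((t - 1.0) ** 2 for t in theta)


def train(steps=200, lr=0.05, dim=8, seed=0):
    """Plain gradient descent driven by the zero-order estimate."""
    rng = random.Random(seed)
    theta = [0.0] * dim
    for _ in range(steps):
        g = cd_rge_gradient(quadratic_loss, theta, rng=rng)
        theta = [t - lr * gi for t, gi in zip(theta, g)]
    return theta, quadratic_loss(theta)
```

Calling `train()` returns the final parameters and loss. Only loss evaluations are needed, so no activations have to be stored for a backward pass, which is the source of the memory savings described above. Each of the `2 * num_perturbations` evaluations is independent of the others, which is what makes the approach amenable to distributed inference.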

Paper Structure

This paper contains 17 sections, 15 equations, 5 figures, 2 tables, 2 algorithms.

Figures (5)

  • Figure 1: Iterations to overfit a fixed batch using BPTT versus CD-RGE on sequence length 100 at varying compute budgets. CD-RGE with 512 perturbations per step outperforms BPTT by up to 19× for small models (300k params) and 2× for larger models (270M params). BPTT cannot train models larger than 270M due to GPU memory constraints.
  • Figure 2: Validation loss trajectories for training large LSTMs on next-character prediction on the Penn Treebank dataset with BPTT and CD-RGE with 512 perturbations per step, across model sizes of 1M, 10M, and 1.1B parameters. CD-RGE converges very smoothly, while BPTT overfits quickly and its loss fluctuates unevenly even with Adam. At 10M parameters, BPTT and CD-RGE converge to similar loss values. Larger models converge faster and to lower loss values, suggesting that RNNs benefit from additional capacity.
  • Figure 3: Comparison of VRAM requirements for training (left) and inference (right) with CD-RGE. Each row varies a key factor — model size, sequence length, or batch size — while keeping others fixed.
  • Figure 4: Ackley's function convolved with a Rademacher distribution scaled by $\varepsilon$. As $\varepsilon$ increases, the function becomes smoother up to a point, beyond which it becomes non-smooth again. Finding the right range is critical; for Ackley's function it is roughly $\varepsilon \in [1, 1.7]$.
  • Figure 5: Visualization of shell smoothing in RGE using Rademacher-like perturbations $\epsilon p_i$ on a 3D sphere with radius $\epsilon \sqrt{d}$. Each point represents a directional probe $\epsilon p_i$ used to estimate the gradient via finite differences. The color denotes the loss value $L(\theta + \epsilon p_i)$, with darker regions indicating lower loss. The red arrow shows the estimated descent direction computed via RGE as per Eq. \ref{eq:fde}. RGE effectively integrates directional loss differences over the shell, yielding a smooth, low-variance estimate of the gradient that consistently points toward lower-loss regions, even in noisy or non-smooth settings.
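
The smoothing described in the captions of Figures 4 and 5 can be sketched in a few lines. The example below is an illustration under stated assumptions, not the paper's code: it evaluates the smoothed surrogate $L_{\varepsilon}(\Theta) = \mathbb{E}_{p}[L(\Theta + \varepsilon p)]$ for the standard Ackley function by enumerating all $2^d$ Rademacher sign vectors, which is only feasible in low dimensions (in practice the expectation would be sampled). The function names `ackley`, `smoothed_loss`, and `probe_radius` are hypothetical.

```python
import itertools
import math


def ackley(x):
    """Standard Ackley function; its global minimum of 0 is at the origin."""
    d = len(x)
    mean_sq = sum(xi * xi for xi in x) / d
    mean_cos = sum(math.cos(2.0 * math.pi * xi) for xi in x) / d
    return (-20.0 * math.exp(-0.2 * math.sqrt(mean_sq))
            - math.exp(mean_cos) + 20.0 + math.e)


def smoothed_loss(loss, theta, eps):
    """Exact E_p[loss(theta + eps * p)] over all Rademacher vectors p.

    Enumerates 2**d sign patterns, so this is for low-dimensional
    illustration only.
    """
    d = len(theta)
    total = 0.0
    for signs in itertools.product((-1.0, 1.0), repeat=d):
        total += loss([t + eps * s for t, s in zip(theta, signs)])
    return total / 2 ** d


def probe_radius(d, eps):
    """Norm of every probe eps * p: all Rademacher vectors have norm sqrt(d)."""
    return eps * math.sqrt(d)
```

For example, `smoothed_loss(ackley, [0.0, 0.0], 1.0)` averages the loss over the four points $(\pm 1, \pm 1)$, all of which lie on the shell of radius `probe_radius(2, 1.0)`. Sweeping `eps` over a grid of `theta` values is one way to explore how the surrogate landscape changes with the smoothing scale.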