
Optimal Gradient Checkpointing for Sparse and Recurrent Architectures using Off-Chip Memory

Wadjih Bencheikh, Jan Finkbeiner, Emre Neftci

TL;DR

This work tackles memory bottlenecks in training sparse and recurrent networks on hardware with limited on-chip memory by introducing memory-efficient gradient checkpointing strategies tailored to sparse RNNs and Spiking Neural Networks (SNNs), and validates them on Graphcore IPUs. It presents four checkpointing schemes (Standard, Remote, Hierarchical, and Double) along with a formal memory-time modeling framework that shows how each scheme trades memory for recomputation. Empirical results demonstrate that Double Checkpointing enables training on sequences up to $10\times$ longer and networks up to $4\times$ larger with minimal overhead, largely by exploiting sparsity and fast recomputation and by reducing remote memory accesses. The findings suggest broad applicability to sparse and recurrent architectures across hardware platforms, with potential for substantial improvements in scalability and energy efficiency, including on-chip-only regimes where memory scales as $O(\sqrt[4]{T})$ under optimal settings.
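The memory-for-recomputation trade-off summarized above can be illustrated with a deliberately simplified model of uniform nested checkpointing. This is a sketch under stated assumptions, not the paper's framework: memory is counted in hidden states, every chunk length divides the next-outer one, and each checkpoint level is charged one full replay of the forward pass. Under these assumptions, $k$ nested levels with equal ratios store about $(k+1)\,T^{1/(k+1)}$ states, so one level gives the familiar $O(\sqrt{T})$ and three levels give $O(\sqrt[4]{T})$, the same scaling as the on-chip-only bound quoted above (how the paper itself derives that bound, including its treatment of remote memory, is not reproduced here). The function names are hypothetical.

```python
# Simplified memory/recompute model for uniform nested checkpointing.
# Illustrative assumption only; not the paper's exact cost model.


def peak_states(T: int, chunk_sizes: list[int]) -> int:
    """Approximate peak number of hidden states held at once.

    chunk_sizes lists nested chunk lengths from outermost to innermost,
    e.g. [C1, C2]; each must evenly divide the span enclosing it (and T).
    An empty list means no checkpointing: all T states are stored.
    """
    spans = [T] + list(chunk_sizes)
    for outer, inner in zip(spans, spans[1:]):
        assert outer % inner == 0, "chunk sizes must divide evenly"
    # One checkpoint per chunk at every level of the currently active span ...
    checkpoints = sum(outer // inner for outer, inner in zip(spans, spans[1:]))
    # ... plus the innermost chunk, fully materialized for the backward pass.
    return checkpoints + spans[-1]


def extra_forward_steps(T: int, chunk_sizes: list[int]) -> int:
    """Rough upper bound: each checkpoint level replays the forward pass once."""
    return len(chunk_sizes) * T


T = 4096
for chunks in ([], [64], [256, 16], [512, 64, 8]):
    print(len(chunks), peak_states(T, chunks), extra_forward_steps(T, chunks))
# 0 4096 0
# 1 128 4096
# 2 48 8192
# 3 32 12288
```

Each added level shrinks peak memory at the price of roughly one more forward pass of recomputation, which is why cheap recomputation (for example, on sparse activity) makes deeper schemes attractive.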

Abstract

Recurrent neural networks (RNNs) are valued for their computational efficiency and reduced memory requirements on tasks involving long sequence lengths, but they require high memory-processor bandwidth to train. Checkpointing techniques can reduce the memory requirements by storing only a subset of intermediate states, the checkpoints, but are still rarely used due to the computational overhead of the additional recomputation phase. This work addresses these challenges by introducing memory-efficient gradient checkpointing strategies tailored to the general class of sparse RNNs and Spiking Neural Networks (SNNs). SNNs are energy-efficient alternatives to RNNs thanks to their local, event-driven operation and potential neuromorphic implementation. We use the Intelligence Processing Unit (IPU) as an exemplary platform for architectures with distributed local memory, and we exploit its suitability for sparse and irregular workloads to scale SNN training to long sequences. We find that Double Checkpointing emerges as the most effective method, optimizing the use of local memory resources while minimizing recomputation overhead. This approach reduces dependency on slower large-scale memory access, enabling training on sequences over 10 times longer, or on networks 4 times larger, than previously feasible, with only marginal time overhead. The presented techniques demonstrate significant potential to enhance scalability and efficiency in training sparse and recurrent networks across diverse hardware platforms, and highlight the benefits of sparse activations for scalable recurrent neural network training.
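The basic checkpointing idea described in the abstract (store only a subset of intermediate states during the forward pass, then recompute the missing ones segment by segment during the backward pass) can be sketched for a toy scalar RNN. This is a minimal illustration of standard checkpointing, assuming the recurrence $h_{t+1} = \tanh(w h_t + x_t)$ and a loss equal to the sum of all hidden states; it is not the paper's IPU implementation, and all function names are hypothetical. The checkpointed gradient should match plain BPTT exactly, because the recomputed states are produced by the same operations as in the forward pass.

```python
import math


def step(h: float, x: float, w: float) -> float:
    """One step of the toy RNN: h_{t+1} = tanh(w * h_t + x_t)."""
    return math.tanh(w * h + x)


def backprop_segment(states, w, dw, dh):
    """BPTT through one segment, where states[i] = h_{start+i}.

    dh enters as dL/dh at the segment's right boundary (from later steps).
    Since the loss is the sum of h_1..h_T, each h_{t+1} adds 1 to dh.
    Returns the updated (dw, dh), with dh now dL/dh_start.
    """
    for i in reversed(range(len(states) - 1)):
        dh += 1.0                                  # direct loss term for h_{start+i+1}
        da = dh * (1.0 - states[i + 1] ** 2)       # back through tanh
        dw += da * states[i]                       # this step's contribution to dL/dw
        dh = da * w                                # dL/dh_{start+i} via the recurrence
    return dw, dh


def grad_full(xs, w, h0=0.0):
    """Reference BPTT: store every hidden state, then backpropagate once."""
    states = [h0]
    for x in xs:
        states.append(step(states[-1], x, w))
    dw, _ = backprop_segment(states, w, 0.0, 0.0)
    return dw


def grad_checkpointed(xs, w, chunk, h0=0.0):
    """Standard checkpointing: keep h_t only when t is a multiple of `chunk`."""
    T = len(xs)
    checkpoints = {0: h0}                          # time index -> stored state
    h = h0
    for t, x in enumerate(xs):
        h = step(h, x, w)
        if (t + 1) % chunk == 0 and t + 1 < T:
            checkpoints[t + 1] = h
    dw, dh = 0.0, 0.0
    for start in sorted(checkpoints, reverse=True):  # last segment first
        end = min(start + chunk, T)
        states = [checkpoints[start]]              # recompute h_start..h_end
        for t in range(start, end):
            states.append(step(states[-1], xs[t], w))
        dw, dh = backprop_segment(states, w, dw, dh)
    return dw


xs = [math.sin(0.3 * t) for t in range(50)]
print(grad_full(xs, 0.8), grad_checkpointed(xs, 0.8, chunk=7))
```

Peak storage drops from all $T$ states to roughly $T/\text{chunk} + \text{chunk}$, at the cost of recomputing each segment once; choosing $\text{chunk} \approx \sqrt{T}$ balances the two terms.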

Paper Structure

This paper contains 11 sections, 12 equations, 5 figures.

Figures (5)

  • Figure 1: Comprehensive overview of the Backpropagation Through Time (BPTT) process, showing how intermediate states are stored and reconstructed, with execution traces for various gradient checkpointing strategies, including Standard, Chunk, Remote, and Double Checkpointing.
  • Figure 2: Performance Comparison of Gradient Checkpointing Strategies Across Sequence Lengths and Model Sizes. Left: Time per batch as a function of sequence length, showing how Double Checkpointing maintains competitive training times compared to other methods, even for longer sequences. Middle: Peak local memory per tile across sequence lengths, highlighting Double Checkpointing's ability to minimize memory usage while scaling. Right: Peak local memory per tile as a function of model size, demonstrating Double Checkpointing's scalability and efficiency in handling larger models with T=300.
  • Figure 3: Memory Efficiency of Double Checkpointing Across Model Sizes and Configurations. The graphs compare the maximum peak local memory per tile for Double Checkpointing versus the base implementation across varying sequence lengths (T) and batch sizes. Left: T=128 and batch_size=120. Middle: T=500 and batch_size=60. Right: T=900 and batch_size=60.
  • Figure 4: Hyperparameter Study of Double and Hierarchical Checkpointing. This figure illustrates the impact of varying chunk size and the number of local checkpoints on training time (right) and memory usage (left) for Double Checkpointing (bottom) and Hierarchical Checkpointing (top).
  • Figure 5: Optimization of Remote Chunk Size for Double Checkpointing. The figure explores the relationship between remote chunk size and the performance of Double Checkpointing, balancing memory efficiency and computational time. Left: Mean local memory peak per tile as a function of remote chunk size. Smaller remote chunk sizes reduce local memory usage but require more frequent synchronization, leading to inefficiencies. Right: Time per batch as a function of remote chunk size. Larger remote chunk sizes minimize synchronization overhead, resulting in faster training times but higher memory usage.
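The remote chunk size trade-off described for Figure 5 can be reproduced qualitatively with a toy schedule counter. This sketch assumes that Double Checkpointing can be modeled as remote (off-chip) checkpoints every `remote_chunk` steps plus local (on-chip) checkpoints every `local_chunk` steps inside each remote chunk; memory is counted in hidden states and each checkpoint write or read counts as one transfer. Real IPU costs (tile layout, exchange synchronization, sparsity) are not modeled, and the function name is hypothetical.

```python
def double_checkpoint_counts(T: int, remote_chunk: int, local_chunk: int):
    """Toy accounting for a remote + local checkpointing schedule.

    Returns (peak_local_states, remote_transfers, recomputed_steps).
    """
    remote_starts = list(range(0, T, remote_chunk))
    remote_transfers = len(remote_starts)          # forward pass: offload checkpoints
    peak_local, recomputed = 0, 0
    for r_start in reversed(remote_starts):        # backward pass, last chunk first
        r_end = min(r_start + remote_chunk, T)
        remote_transfers += 1                      # fetch this remote checkpoint
        # Replay the remote chunk once, keeping local checkpoints on chip.
        local_starts = list(range(r_start, r_end, local_chunk))
        recomputed += r_end - r_start
        for i in reversed(range(len(local_starts))):
            l_start = local_starts[i]
            l_end = min(l_start + local_chunk, r_end)
            recomputed += l_end - l_start          # replay the innermost chunk
            # Earlier local checkpoints still held + states h_{l_start..l_end}.
            live = i + (l_end - l_start + 1)
            peak_local = max(peak_local, live)
    return peak_local, remote_transfers, recomputed


for R in (50, 100, 500):
    print(R, double_checkpoint_counts(T=1000, remote_chunk=R, local_chunk=10))
# 50 (15, 40, 2000)
# 100 (20, 20, 2000)
# 500 (60, 4, 2000)
```

In this toy model the recomputation count does not depend on `remote_chunk`, so shrinking it buys lower local memory purely at the price of more off-chip transfers. That matches the qualitative direction stated in the Figure 5 caption, although the measured IPU timings also include synchronization and exchange costs that this sketch ignores.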