GPU Memory Usage Optimization for Backward Propagation in Deep Network Training

Ding-Yong Hong, Tzu-Hsien Tsai, Ning Wang, Pangfeng Liu, Jan-Jan Wu

TL;DR

Deep network training incurs substantial GPU memory pressure from intermediate activations. The authors cast checkpoint selection as a peak-memory minimization problem, solving it first with an $O(n^3)$ dynamic-programming algorithm and then with a refined $O(n)$-time algorithm whose objective mirrors PyTorch's memory behavior. They demonstrate that aligning the memory model with the framework yields tighter peak-memory bounds at comparable training times, outperforming prior $O(\sqrt{n})$ and ACG-based methods in practice. The work enables training larger models within fixed GPU memory and provides a principled, efficient method for memory-aware training on modern DL platforms.

Abstract

In modern Deep Learning, the trend is to design larger Deep Neural Networks (DNNs) to execute more complex tasks with better accuracy. Meanwhile, Convolutional Neural Networks (CNNs) have become the standard method for most computer vision tasks. However, the memory allocated for intermediate data in convolution layers can cause severe memory pressure during model training. Many solutions have been proposed to resolve this problem. Besides hardware-dependent solutions, a general methodology, rematerialization, can reduce GPU memory usage by trading computation for memory. The idea is to select a set of intermediate results during the forward phase as checkpoints and save only those in memory. The backward phase recomputes the remaining intermediate data from the closest checkpoints as needed. This recomputation increases execution time but saves memory, because not all intermediate results are stored during the forward phase. In this paper, we focus on efficiently finding the optimal checkpoint subset that achieves the least peak memory usage during model training. We first describe the theoretical background of neural network training using mathematical equations. We use these equations to identify all essential data required during both the forward and backward phases to compute the gradients of the model's weights. We then formulate the checkpoint selection problem and propose a dynamic programming algorithm with time complexity $O(n^3)$ for finding the optimal checkpoint subset. Based on extensive experiments and memory tracing, we formulate a more accurate description of the problem using our theoretical analysis, revise the objective function accordingly, and propose an $O(n)$-time algorithm for finding the optimal checkpoint subset.
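
The trade-off described in the abstract can be illustrated with a small sketch. It assumes a deliberately simplified memory model that is not the paper's objective function: every checkpoint stays resident for the whole backward phase, and only one segment between consecutive checkpoints is recomputed and held at a time. The names `peak_memory` and `best_peak_bruteforce`, and the layer sizes, are hypothetical; the exhaustive search is exponential and stands in only as a reference point, not as the $O(n^3)$ or $O(n)$ algorithms.

```python
from itertools import combinations


def peak_memory(sizes, checkpoints):
    """Peak memory under the simplified (assumed) model.

    sizes[i] is the activation size of layer i; checkpoints is an
    iterable of layer indices whose activations are kept in memory.
    Peak = total size of checkpoints + size of the largest segment
    of non-checkpoint layers that must be recomputed together.
    """
    kept = set(checkpoints)
    resident = sum(sizes[i] for i in kept)
    largest_segment = 0
    current_segment = 0
    for i, size in enumerate(sizes):
        if i in kept:
            # A checkpoint closes the current segment.
            largest_segment = max(largest_segment, current_segment)
            current_segment = 0
        else:
            current_segment += size
    largest_segment = max(largest_segment, current_segment)
    return resident + largest_segment


def best_peak_bruteforce(sizes):
    """Smallest peak over all checkpoint subsets (exponential time)."""
    n = len(sizes)
    return min(
        peak_memory(sizes, subset)
        for k in range(n + 1)
        for subset in combinations(range(n), k)
    )


if __name__ == "__main__":
    sizes = [4, 1, 4, 4, 1, 4]            # hypothetical activation sizes
    print(peak_memory(sizes, ()))         # no checkpoints: one big segment
    print(peak_memory(sizes, (1, 4)))     # checkpoint the two small layers
    print(best_peak_bruteforce(sizes))    # best achievable in this model
```

Under this toy model, checkpointing the small activations keeps the resident cost low while still splitting the network into short segments, which is why the choice of subset matters and not only its size.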

Paper Structure

This paper contains 31 sections, 7 theorems, 16 equations, 9 figures, 3 tables, 6 algorithms.

Key Result

Theorem 1

The checkpoint selection problem is $O(n^3)$-time solvable, where $n$ is the number of neural network layers.

Figures (9)

  • Figure 1: An example of three checkpoints (in black) and two segments: {d1} and {d3, d4} (in white)
  • Figure 2: The values of $M(\cdot)$ and $U(i, \cdot)$ when we compute $M(2)$ and $M(1)$. The $x$-coordinate represents the index of $M(\cdot)$ and $U(i, \cdot)$, while the $y$-coordinate represents their corresponding values. Black dots indicate the values of $M(\cdot)$ present in the queue $Q_i$, and stars represent the values of $U(i, \cdot)$. The $x$-coordinate $i$ without a black dot indicates that $M(i)$ has been removed from the queue $Q$. The $j$ below the $x$-axis marks the value of $j$ after the "while" loop in Algorithm 6. The $j^*(i)$ below the $x$-axis marks the updated value of $j^*(i)$.
  • Figure 3: Training phase indices of VGG-19.
  • Figure 4: Training phase indices of AlexNet.
  • Figure 5: GPU Memory Usage: Algorithm Prediction vs PyTorch Report on VGG-19 with Checkpoint Subset {3, 11, 24}.
  • ...and 4 more figures

Theorems & Definitions (12)

  • Theorem 1
  • Theorem 2
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • Lemma 4
  • proof
  • ...and 2 more