Lowering PyTorch's Memory Consumption for Selective Differentiation

Samarth Bhatia, Felix Dangel

TL;DR

The paper tackles the memory bottleneck of PyTorch's automatic differentiation in the context of selective differentiation, where gradients are requested for only a subset of graph leaves. It introduces differentiability-aware, drop-in memory-saving layers that discard layer inputs when weights are non-differentiable, preserving forward/backward behavior while reducing AD graph storage. Empirical evaluations on CNNs and Transformers show memory reductions without runtime penalties, though gains depend on layer interactions and activation choices; a mask-based ReLU and evaluation-mode BN can substantially amplify savings. The approach is a practical, easy-to-use converter that enables scalable selective fine-tuning and highlights avenues for further AD-optimization research.

Abstract

Memory is a limiting resource for many deep learning tasks. Besides the neural network weights, one main memory consumer is the computation graph built up by automatic differentiation (AD) for backpropagation. We observe that PyTorch's current AD implementation neglects information about parameter differentiability when storing the computation graph. This information can, however, be used to reduce memory whenever gradients are requested for only a subset of parameters, as is the case in many modern fine-tuning tasks. Specifically, inputs to layers that act linearly in their parameters (dense, convolution, or normalization layers) can be discarded whenever the parameters are marked as non-differentiable. We provide a drop-in, differentiability-agnostic implementation of such layers and demonstrate its ability to reduce memory without affecting run time.
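
The idea in the abstract can be illustrated with a small custom autograd function. The following is a minimal sketch, not the authors' implementation: the name `MemSaveLinearFn` is hypothetical, the layer has no bias, and the input is assumed to be a 2D batch. It relies on the fact that for z = x Wᵀ the gradient with respect to x needs only W, while the gradient with respect to W needs only x, so the input is stored only when the weight requires gradients.

```python
import torch


class MemSaveLinearFn(torch.autograd.Function):
    """Hypothetical bias-free linear op z = x @ W^T for 2D inputs.

    Stores only the tensors that the requested gradients actually need.
    """

    @staticmethod
    def forward(ctx, x, weight):
        # The weight gradient needs x; the input gradient needs the weight.
        ctx.has_x = weight.requires_grad
        ctx.has_w = x.requires_grad
        to_save = []
        if ctx.has_x:
            to_save.append(x)
        if ctx.has_w:
            to_save.append(weight)
        ctx.save_for_backward(*to_save)
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_out):
        saved = list(ctx.saved_tensors)
        x = saved.pop(0) if ctx.has_x else None
        weight = saved.pop(0) if ctx.has_w else None
        grad_x = grad_w = None
        if ctx.needs_input_grad[0]:
            grad_x = grad_out @ weight  # shape (batch, in_features)
        if ctx.needs_input_grad[1]:
            grad_w = grad_out.T @ x  # shape (out_features, in_features)
        return grad_x, grad_w


if __name__ == "__main__":
    x = torch.randn(4, 3, requires_grad=True)
    frozen_weight = torch.randn(5, 3)  # requires_grad is False: x is not stored
    MemSaveLinearFn.apply(x, frozen_weight).sum().backward()
    print(x.grad.shape)
```

With a frozen weight, only the weight itself is kept for the backward pass, and it already lives in memory as a parameter, so no additional activation is retained by this layer.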

Paper Structure (21 sections, 9 figures, 5 tables)

Figures (9)

  • Figure 1: PyTorch's AD is sometimes not agnostic to parameter differentiability. We consider a deep CNN made of size-preserving convolutions and measure the forward pass's peak memory when processing a mini-batch of size (256, 8, 256, 256), requiring 512 MiB memory. Memory increases linearly in the number of layers when all parameters are marked differentiable and remains constant when all parameters are marked non-differentiable. Surprisingly, when only one layer's parameters are marked as differentiable the memory increases as if all subsequent parameters were marked differentiable. Our drop-in solution stores layer inputs depending on parameter differentiability and reduces memory compared to the current PyTorch implementation.
  • Figure 2: PyTorch's behaviour of storing the computation graph, illustrated on a convolution ${\bm{Z}} = {\bm{W}} * {\bm{X}}$. PyTorch stores the layer input whenever it is differentiable, although this is not necessary if the weight does not require gradients. Our approach uses this information to discard the layer input if possible. See \ref{supp:torchviz_diagrams} for computation graphs visualized with torchviz.
  • Figure 3: Probing different PyTorch layers for their awareness of parameter differentiability. (\ref{subfig:visual-abstract-linear}) PyTorch's linear layer is agnostic to parameter differentiability. (\ref{subfig:visual-abstract-conv1d}, \ref{subfig:visual-abstract-conv2d}, \ref{subfig:visual-abstract-conv3d}) PyTorch's convolution layers (1d, 2d, 3d), (\ref{subfig:visual-abstract-conv-transpose1d}, \ref{subfig:visual-abstract-conv-transpose2d}, \ref{subfig:visual-abstract-conv-transpose3d}) transposed convolution layers (1d, 2d, 3d), and (\ref{subfig:visual-abstract-bn2d}) 2d batch normalization layer in evaluation mode are not agnostic to parameter differentiability.
  • Figure 4: Probing different PyTorch layers for their awareness of parameter differentiability with .
  • Figure 5: Computation graphs of a convolution layer for the Input case. Although only the input is differentiable, so the weight gradient is never needed, PyTorch saves the input (as can be seen inside the node - the input is of shape ). MemSave, on the other hand, does not save the input.
  • ...and 4 more figures
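
The 512 MiB figure quoted in the caption of Figure 1 can be checked with a short calculation. The sketch below assumes single-precision (4-byte) activations, which is consistent with the quoted number but is not stated in the caption. The helper `stored_inputs` is a hypothetical, idealized count of layer inputs kept for the backward pass in a chain of size-preserving convolutions where only one layer has differentiable parameters; it ignores weights, outputs, and allocator effects, and its results are not measurements from the paper.

```python
from math import prod

BYTES_PER_FLOAT32 = 4  # assumption: activations are stored in single precision
MIB = 2**20


def tensor_mib(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Size of a dense tensor with the given shape, in MiB."""
    return prod(shape) * bytes_per_element / MIB


def stored_inputs(num_layers, first_trainable, aware):
    """Idealized count of layer inputs kept for backward.

    Only layer `first_trainable` (0-based) has differentiable parameters and
    the network input does not require gradients. A differentiability-unaware
    layer keeps its input whenever that input is differentiable, which holds
    for every layer after the trainable one.
    """
    if aware:
        return 1  # only the trainable layer needs its input
    return num_layers - first_trainable


if __name__ == "__main__":
    activation = tensor_mib((256, 8, 256, 256))
    print(activation)  # 512.0
    # Hypothetical 10-layer network whose fourth layer is trainable.
    print(stored_inputs(10, 3, aware=False) * activation)  # 3584.0
    print(stored_inputs(10, 3, aware=True) * activation)  # 512.0
```

Under these assumptions, the unaware case grows linearly with the number of layers after the trainable one, while the aware case stays at a single stored activation, which matches the qualitative behaviour described in the caption.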