Efficient and Scalable Implementation of Differentially Private Deep Learning without Shortcuts

Sebastian Rodriguez Beltran, Marlon Tobaben, Joonas Jälkö, Niki Loppi, Antti Honkela

TL;DR

The paper addresses the high computational cost of training deep models under differential privacy using DP-SGD with Poisson subsampling. It introduces a JAX-based, Poisson-compliant approach (Masked DP-SGD) that avoids recompilation and leverages efficient clipping and lower precision to reduce overhead, validating it via extensive experiments across vision and NLP tasks on up to 80 GPUs. Key contributions include re-implementing Poisson-subsampled DP-SGD, a detailed cost analysis, practical speedups from clipping optimizations and TF32 precision, and an open-source library. The work demonstrates that DP-SGD can scale more favorably than non-private training in distributed settings, offering a viable path to scalable private deep learning. Overall, it provides actionable guidance and tooling to implement, benchmark, and deploy DP-SGD with correct privacy accounting in real-world settings.

Abstract

Differentially private stochastic gradient descent (DP-SGD) is the standard algorithm for training machine learning models under differential privacy (DP). The most common DP-SGD privacy accountants rely on Poisson subsampling to ensure the theoretical DP guarantees. Implementing computationally efficient DP-SGD with Poisson subsampling is not trivial, which leads many implementations to take a shortcut by using computationally faster subsampling. We quantify the computational cost of training deep learning models under DP by implementing and benchmarking efficient methods with correct Poisson subsampling. We find that the naive implementation of DP-SGD with Opacus in PyTorch has a throughput 2.6 to 8 times lower than that of SGD. However, efficient gradient clipping implementations like Ghost Clipping can roughly halve this cost. We propose an alternative computationally efficient implementation of DP-SGD with JAX that uses Poisson subsampling and performs comparably to efficient clipping optimizations based on PyTorch. We study the scaling behavior using up to 80 GPUs and find that DP-SGD scales better than SGD. We share our library at https://github.com/DPBayes/Towards-Efficient-Scalable-Training-DP-DL.
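To make the private training step concrete, the following is a minimal, illustrative JAX sketch of the per-example clipping and noising that DP-SGD performs. It is not the authors' library: the toy loss, parameter names, and hyperparameters are assumptions for illustration, and a production implementation would additionally handle Poisson subsampling, jit compilation with fixed shapes, and scaling by the expected batch size.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model with squared loss; stands in for any per-example loss.
    return (x @ params - y) ** 2

def dp_sgd_step(params, batch_x, batch_y, key,
                clip_norm=1.0, noise_mult=1.0, lr=0.1):
    # Per-example gradients via vmap over the batch dimension.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, batch_x, batch_y)
    # Clip each per-example gradient to L2 norm <= clip_norm.
    norms = jnp.linalg.norm(grads, axis=1)
    factors = jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    clipped = grads * factors[:, None]
    # Sum clipped gradients, add Gaussian noise scaled by noise_mult * clip_norm, average.
    noise = noise_mult * clip_norm * jax.random.normal(key, params.shape)
    noisy_grad = (clipped.sum(axis=0) + noise) / batch_x.shape[0]
    return params - lr * noisy_grad
```

Wrapping such a step in `jax.jit` reuses the compiled function as long as batch shapes stay fixed, which is why the padding and masking of variable-size Poisson batches discussed below matters for efficiency.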

Paper Structure

This paper contains 24 sections, 1 theorem, 5 equations, 21 figures, 10 tables.

Key Result

Lemma 1

Poisson subsampling from a data set $D$ of size $N$ with subsampling rate $q$ is identical to the following procedure
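The procedure referred to by the lemma is truncated in this extract. For background only (not a reconstruction of the lemma's conclusion), Poisson subsampling with rate $q$ includes each element of $D$ independently with probability $q$, so the realized batch size is itself random:

```latex
\Pr\!\left[x_i \in B\right] = q \quad \text{independently for } i = 1,\dots,N,
\qquad
\Pr\!\left[\,|B| = b\,\right] = \binom{N}{b} q^{b} (1-q)^{N-b},
\qquad
\mathbb{E}\!\left[\,|B|\,\right] = qN.
```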

Figures (21)

  • Figure 1: Relative throughput (FP32) with respect to the corresponding non-private baseline (higher is better) on an NVIDIA A100. For each optimization method and model size, we divide its throughput by that of its non-private counterpart. Throughput is the number of processed instances per second. In this benchmark we distinguish between precision modes; they are available in both frameworks and significantly improve the throughput of the different DP-SGD implementations.
  • Figure 2: A variant of the DP-SGD algorithm [dp-sgd-rajkumar-2012; dp-sgd-song-2013; dp-sgd-abadi-2016] that supports virtual batching, splitting logical batches of size $b$ into smaller physical batches of size $p$ that can be processed in memory.
  • Figure 3: Our proposed algorithm, Masked DP-SGD, with differences to the default virtual batching algorithm (Figure 2) highlighted in blue (see the sketch after this figure list).
  • Figure 4: Expected number of additionally computed gradients per step for our proposed method and for [chua2024scalable]. We assume a dataset size of $N=50000$ and simulate with $\epsilon \in \{1,8\}$ at $\delta=10^{-5}$ and 40 epochs for all $L/N$, which means the number of iterations is $T=40\times N/L$. We set $\tau = 10^{-5}$ as done by [chua2024scalable]. We also plot (dotted lines) the simple upper bound computed at the beginning of the subsection. We observe that our method requires significantly fewer additional gradients for $p\leq256$.
  • Figure 5: Expected number of additionally computed gradients per step for our proposed method, illustrated for a hypothetical scenario with dataset size $N=50000$ and $\frac{L}{N}=0.5$. Choosing the maximum physical batch size $p=1024$ is not optimal in terms of expected additionally computed gradients per step.
  • ...and 16 more figures
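The captions of Figures 2–5 describe virtual batching (a variable-size Poisson-sampled logical batch processed as fixed-size physical batches) and a masking scheme that keeps compiled shapes static. The sketch below is a rough, hypothetical illustration of that padding-and-masking idea, assuming padded examples are what is counted as the "additionally computed" gradients; the function and variable names are ours, not the library's API.

```python
import math
import jax.numpy as jnp

def pad_and_mask(batch_x, batch_y, physical_batch_size):
    """Pad a Poisson-sampled logical batch up to a multiple of the physical
    batch size and build a 0/1 mask so padded examples contribute nothing.
    Keeping shapes fixed per physical batch avoids XLA recompilation."""
    b = batch_x.shape[0]                              # realized logical batch size
    n_physical = math.ceil(b / physical_batch_size)   # physical batches this step
    padded_size = n_physical * physical_batch_size
    extra = padded_size - b                           # padded ("additionally computed") gradients
    pad_x = jnp.zeros((extra,) + batch_x.shape[1:], batch_x.dtype)
    pad_y = jnp.zeros((extra,) + batch_y.shape[1:], batch_y.dtype)
    mask = jnp.concatenate([jnp.ones(b), jnp.zeros(extra)])
    x = jnp.concatenate([batch_x, pad_x])
    y = jnp.concatenate([batch_y, pad_y])
    return x, y, mask, extra
```

Inside the training step, each clipped per-example gradient would be multiplied by its mask entry before accumulation across physical batches, so padding examples are computed only for shape stability and never affect the update or the privacy accounting; `extra` corresponds to the per-step overhead that Figures 4 and 5 quantify in expectation.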

Theorems & Definitions (2)

  • Lemma 1
  • Proof