Efficient and Scalable Implementation of Differentially Private Deep Learning without Shortcuts
Sebastian Rodriguez Beltran, Marlon Tobaben, Joonas Jälkö, Niki Loppi, Antti Honkela
TL;DR
The paper addresses the high computational cost of training deep models under differential privacy using DP-SGD with Poisson subsampling. It introduces a JAX-based, Poisson-compliant approach (Masked DP-SGD) that avoids recompilation and leverages efficient clipping and lower precision to reduce overhead, validating the approach through extensive experiments on vision and NLP tasks with up to 80 GPUs. Key contributions include a re-implementation of Poisson-subsampled DP-SGD, a detailed computational cost analysis, practical speedups from clipping optimizations and TF32 precision, and an open-source library. The work demonstrates that DP-SGD can scale more favorably than non-private training in distributed settings, offering a viable path to scalable private deep learning. Overall, it provides actionable guidance and tooling for implementing, benchmarking, and deploying DP-SGD with correct privacy accounting in real-world settings.
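To make the masking idea concrete, below is a minimal sketch (not the authors' exact implementation) of how Poisson subsampling can be combined with fixed-shape padding in JAX: each example is included independently with the sampling probability, the sampled indices are padded to a static maximum size, and a boolean mask marks the real rows. Because the returned shapes never change, a jit-compiled training step compiles once instead of once per sampled batch size. The function and parameter names (`poisson_masked_batch`, `max_batch_size`) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def poisson_masked_batch(key, data, sample_rate, max_batch_size):
    """Draw a Poisson-subsampled batch padded to a fixed shape."""
    n = data.shape[0]
    # Poisson subsampling: each example is included independently
    # with probability `sample_rate`, as the accountant assumes.
    include = jax.random.bernoulli(key, p=sample_rate, shape=(n,))
    # Indices of the sampled rows, padded with 0 up to the static size.
    idx = jnp.nonzero(include, size=max_batch_size, fill_value=0)[0]
    # Mask marking which rows are real; padded rows are ignored downstream.
    mask = jnp.arange(max_batch_size) < jnp.sum(include)
    # Static output shapes: a jitted training step never recompiles.
    return data[idx], mask
```

In this sketch, `max_batch_size` must be chosen with enough headroom above the expected batch size `n * sample_rate` that truncating an oversized draw is vanishingly rare; how the paper handles this corner case is not stated here. For the TF32 speedups mentioned above, JAX exposes the `jax.default_matmul_precision("tensorfloat32")` context manager as one way to opt matrix multiplications into TensorFloat-32 on supported GPUs, though whether the paper uses this exact mechanism is an assumption.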
Abstract
Differentially private stochastic gradient descent (DP-SGD) is the standard algorithm for training machine learning models under differential privacy (DP). The most common DP-SGD privacy accountants rely on Poisson subsampling to ensure the theoretical DP guarantees. Implementing computationally efficient DP-SGD with Poisson subsampling is not trivial, which leads many implementations to take a shortcut by using computationally faster subsampling schemes. We quantify the computational cost of training deep learning models under DP by implementing and benchmarking efficient methods with correct Poisson subsampling. We find that the naive implementation of DP-SGD with Opacus in PyTorch has a throughput between 2.6 and 8 times lower than that of SGD. However, efficient gradient clipping implementations such as Ghost Clipping can roughly halve this cost. We propose an alternative, computationally efficient implementation of DP-SGD in JAX that uses Poisson subsampling and performs comparably to the PyTorch-based efficient clipping optimizations. We study the scaling behavior using up to 80 GPUs and find that DP-SGD scales better than SGD. We share our library at https://github.com/DPBayes/Towards-Efficient-Scalable-Training-DP-DL.
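For readers unfamiliar with the per-example clipping step that makes DP-SGD expensive, here is a hedged JAX sketch of one DP-SGD gradient computation compatible with the masked batches above: per-example gradients via `jax.vmap(jax.grad(...))`, per-example clipping to a global L2 norm, masking of padded rows, and Gaussian noise calibrated to the clipping norm. This is a textbook-style sketch, not the paper's or Opacus' API; `loss_fn` (taking `(params, example)`), `clip_norm`, and `noise_multiplier` are illustrative names.

```python
import jax
import jax.numpy as jnp

def dp_sgd_noisy_grad(params, batch, mask, loss_fn, clip_norm,
                      noise_multiplier, key):
    """Return the clipped, masked, noised gradient sum over a batch."""
    # Per-example gradients: vmap the gradient of a single-example loss.
    grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0))(params, batch)
    # Per-example global L2 norm across all parameter leaves.
    sq_norms = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1)
                   for g in jax.tree_util.tree_leaves(grads))
    # Clip factors min(1, C / ||g_i||); multiplying by `mask` zeroes
    # the padded rows so they contribute nothing to sum or sensitivity.
    factors = jnp.minimum(1.0, clip_norm / jnp.sqrt(sq_norms + 1e-12)) * mask
    # Clip and sum over the batch axis in one contraction per leaf.
    clipped = jax.tree_util.tree_map(
        lambda g: jnp.einsum("b...,b->...", g, factors), grads)
    # Add Gaussian noise calibrated to the clipping norm.
    leaves, treedef = jax.tree_util.tree_flatten(clipped)
    keys = jax.random.split(key, len(leaves))
    noisy = [g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape)
             for g, k in zip(leaves, keys)]
    # Caller divides by the expected batch size (n * sample_rate).
    return jax.tree_util.tree_unflatten(treedef, noisy)
```

In this sketch the mask makes padded rows contribute exactly zero to both the clipped sum and the per-example sensitivity bound, so the fixed-shape padding does not change the Poisson-subsampling privacy analysis; materializing per-example gradients this way is also what optimizations like Ghost Clipping aim to avoid.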
