
JAX-Privacy: A library for differentially private machine learning

Ryan McKenna, Galen Andrew, Borja Balle, Vadym Doroshenko, Arun Ganesh, Weiwei Kong, Alex Kurakin, Brendan McMahan, Mikhail Pravilov

TL;DR

JAX-Privacy tackles the challenge of deploying robust, verifiable differential privacy (DP) for ML on sensitive data by providing a centralized, modular library of DP primitives. It delivers drop-in, JAX-native components for batch selection, per-example gradient clipping, noise addition, accounting, and auditing, along with a Keras API that bridges research flexibility and production readiness. The library supports a range of mechanisms, including DP-SGD and correlated-noise strategies, with integrated privacy budgeting in terms of $(ε, δ)$ and auditing to quantify empirical privacy leakage. Empirical results show throughput competitive with non-private baselines across several architectures and significantly lower overhead than prior DP libraries, making private training practical at scale.
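The sketch below shows, in plain JAX, how the DP-SGD components named above fit together: per-example gradient clipping followed by Gaussian noise addition. It is not JAX-Privacy's API. The functions `loss_fn`, `clip_by_global_norm`, and `dp_sgd_step`, the toy linear model, and the default hyperparameters are all hypothetical. The sketch also assumes a fixed batch size, whereas a full DP pipeline would pair this step with a proper batch-selection scheme.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical toy model: linear regression, squared error on ONE example.
    pred = jnp.dot(x, params["w"]) + params["b"]
    return 0.5 * (pred - y) ** 2


def clip_by_global_norm(grads, clip_norm):
    # Rescale one example's gradient pytree so its global L2 norm is <= clip_norm.
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, clip_norm / (norm + 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)


def dp_sgd_step(params, xs, ys, key, *, clip_norm=1.0, noise_multiplier=1.1, lr=0.1):
    # Per-example gradients: vmap grad over the batch axis, sharing params.
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    # Clip each example's gradient independently (bounds each example's influence).
    clipped = jax.vmap(clip_by_global_norm, in_axes=(0, None))(per_example_grads, clip_norm)
    # Sum the clipped gradients over the batch axis.
    summed = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0), clipped)
    # Add Gaussian noise with std = noise_multiplier * clip_norm to every leaf;
    # clip_norm is the sensitivity of the clipped sum to adding/removing one example.
    leaves, treedef = jax.tree_util.tree_flatten(summed)
    keys = jax.random.split(key, len(leaves))
    noisy_leaves = [
        g + noise_multiplier * clip_norm * jax.random.normal(k, g.shape)
        for g, k in zip(leaves, keys)
    ]
    noisy = jax.tree_util.tree_unflatten(treedef, noisy_leaves)
    # Average over the (assumed fixed) batch size and take a plain SGD step.
    batch_size = xs.shape[0]
    return jax.tree_util.tree_map(lambda p, g: p - lr * g / batch_size, params, noisy)
```

Setting `noise_multiplier=0.0` and a very large `clip_norm` reduces `dp_sgd_step` to ordinary mini-batch SGD, which makes a convenient sanity check. Composing `jax.vmap` with `jax.grad` is one common way to obtain per-example gradients in JAX.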

Abstract

JAX-Privacy is a library designed to simplify the deployment of robust and performant mechanisms for differentially private machine learning. Guided by the design principles of usability, flexibility, and efficiency, JAX-Privacy serves both researchers who need deep customization and practitioners who want a more out-of-the-box experience. The library provides verified, modular primitives for every critical component of mechanism design, including batch selection, gradient clipping, noise addition, accounting, and auditing, and it brings together a large body of recent research on differentially private ML.
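Accounting maps a noise level and a number of training steps to an $(ε, δ)$ guarantee. The sketch below illustrates this with a simple textbook route, chosen here as an assumption: zero-concentrated DP (zCDP) composition followed by the standard zCDP-to-$(ε, δ)$ conversion. It ignores privacy amplification by subsampling, so it is much looser than accountants that model batch sampling. It is not JAX-Privacy's accountant, and all function names are hypothetical.

```python
import math


def zcdp_rho(noise_multiplier: float, steps: int) -> float:
    # A Gaussian step with noise std = noise_multiplier * sensitivity is
    # rho-zCDP with rho = 1 / (2 * noise_multiplier**2); zCDP composes additively.
    return steps / (2.0 * noise_multiplier ** 2)


def epsilon_from_zcdp(rho: float, delta: float) -> float:
    # Standard conversion: rho-zCDP implies (rho + 2*sqrt(rho*ln(1/delta)), delta)-DP.
    return rho + 2.0 * math.sqrt(rho * math.log(1.0 / delta))


def calibrate_noise_multiplier(target_eps: float, delta: float, steps: int,
                               lo: float = 1e-3, hi: float = 1e4,
                               iters: int = 100) -> float:
    # Bisection on the noise multiplier: epsilon decreases as noise grows.
    # Invariant: `hi` always meets the target (assuming the initial hi does).
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        if epsilon_from_zcdp(zcdp_rho(mid, steps), delta) > target_eps:
            lo = mid  # too little noise: budget exceeded
        else:
            hi = mid  # budget met: try less noise
    return hi
```

For example, `calibrate_noise_multiplier(8.0, 1e-5, 1000)` returns, up to bisection tolerance, the smallest noise multiplier that keeps ε at or below 8 for 1000 full-batch steps under this loose bound. A subsampling-aware accountant would permit less noise for the same budget.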

Paper Structure (14 sections, 1 figure, 2 tables)