CrypTorch: PyTorch-based Auto-tuning Compiler for Machine Learning with Multi-party Computation
Jinyu Liu, Gang Tan, Kiwan Maeng
TL;DR
CrypTorch tackles the latency bottleneck in MPC-based ML with a modular, multi-stage compiler that separates operator approximation from the MPC runtime. It builds on CrypTen++, a stronger baseline, and adds an automatic, per-layer approximation tuner that searches for fast yet accurate implementations of operators that MPC cannot run natively, such as Softmax and GELU. Auto-tuning alone yields a 1.20–1.7× speedup over CrypTen++ without sacrificing accuracy and 1.31–1.8× when some accuracy degradation is allowed; the full framework is 3.22–8.6× faster end to end than CrypTen. CrypTorch is integrated as an extension to PyTorch 2's compiler for ease of adoption. By providing an easy interface for adding new approximations and a reusable IR, CrypTorch enables rapid exploration of better MPC kernels and paves the way for broader deployment of privacy-preserving ML workloads.
Abstract
Machine learning (ML) often involves private data and proprietary model parameters. MPC-based ML uses multi-party computation (MPC) to let multiple parties collaboratively run an ML workload without sharing their private data or model parameters. Because MPC cannot natively run ML operations such as Softmax or GELU, existing frameworks replace them with various approximations. Our study shows that, on a well-optimized framework, these approximations often become the dominant bottleneck. Popular approximations are often insufficiently accurate or unnecessarily slow, and these issues are hard to identify and fix in existing frameworks. To tackle this issue, we propose CrypTorch, a compiler for MPC-based ML. CrypTorch disentangles these approximations from the rest of the MPC runtime, makes it easy to add new approximations through its programming interface, and automatically selects approximations to maximize both performance and accuracy. CrypTorch is built as an extension to PyTorch 2's compiler. We show that CrypTorch's auto-tuning alone provides an immediate 1.20--1.7$\times$ speedup without sacrificing accuracy, and a 1.31--1.8$\times$ speedup when some accuracy degradation is allowed, compared to our well-optimized baseline. Combined with better engineering and adoption of state-of-the-art practices, the entire framework brings a 3.22--8.6$\times$ end-to-end speedup compared to the popular framework CrypTen.
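
To make the accuracy/cost trade-off behind such approximations concrete, the following is a minimal plaintext sketch, not CrypTorch's implementation or API. It assumes a limit-based approximation of the exponential, exp(x) ≈ (1 + x/2^n)^(2^n), which needs only additions and multiplications and is therefore the kind of formula an MPC runtime can evaluate. The function names (`exp_limit_approx`, `softmax_approx`, `pick_min_iters`) and the greedy search are hypothetical and chosen only for illustration; a real tuner would measure accuracy on the model's outputs and cost on the MPC backend.

```python
import math


def exp_limit_approx(x: float, iters: int) -> float:
    """Approximate exp(x) as (1 + x / 2**iters) ** (2**iters).

    Only one scaling by a public constant, one addition, and `iters`
    squarings are needed, so cost grows linearly with `iters`.
    """
    y = 1.0 + x / (2 ** iters)
    for _ in range(iters):
        y = y * y  # each squaring doubles the effective exponent
    return y


def softmax_exact(logits):
    """Reference softmax computed with math.exp."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_approx(logits, iters):
    """Softmax built on the approximate exponential.

    Illustration only: the max and the division are done in plaintext
    here, whereas under MPC they would need their own protocols.
    """
    m = max(logits)
    exps = [exp_limit_approx(v - m, iters) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def max_abs_error(logits, iters):
    """Largest elementwise gap between approximate and exact softmax."""
    exact = softmax_exact(logits)
    approx = softmax_approx(logits, iters)
    return max(abs(a - e) for a, e in zip(approx, exact))


def pick_min_iters(logits, tol, candidates=range(1, 13)):
    """Hypothetical toy tuner: cheapest setting whose error is within tol."""
    for iters in candidates:
        if max_abs_error(logits, iters) <= tol:
            return iters
    return None  # no candidate met the tolerance


if __name__ == "__main__":
    sample = [2.0, 1.0, 0.1, -1.5]
    for n in (2, 4, 8, 10):
        print(n, max_abs_error(sample, n))
    print("chosen iterations:", pick_min_iters(sample, tol=1e-2))
```

In this toy setting, too few iterations give a visibly wrong Softmax, while extra iterations beyond the tolerance only add multiplications. That is the per-operator tension between "insufficiently accurate" and "unnecessarily slow" that the abstract describes.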
