CrypTorch: PyTorch-based Auto-tuning Compiler for Machine Learning with Multi-party Computation

Jinyu Liu, Gang Tan, Kiwan Maeng

TL;DR

CrypTorch tackles the latency bottleneck in MPC-based ML with a modular, multi-stage compiler that separates operator approximation from the MPC runtime. It builds on CrypTen++, a well-optimized baseline, and adds an automatic per-layer approximation tuner that searches for fast yet accurate implementations of operators that MPC cannot run natively, such as Softmax and GELU. The auto-tuner alone yields 1.20–1.7× speedups over CrypTen++ without accuracy loss and 1.31–1.8× when some accuracy degradation is allowed, while the full framework achieves 3.22–8.6× end-to-end speedups over CrypTen. CrypTorch is built as an extension to PyTorch 2's compiler for ease of adoption. By providing an easy interface for adding new approximations and a reusable IR, it enables rapid exploration of better MPC kernels and paves the way for broader deployment of privacy-preserving ML workloads.
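
The separation between approximations and the MPC runtime can be illustrated with a small Python sketch. This is not CrypTorch's API: `register_approx`, `APPROXIMATIONS`, `PlainBackend`, and `CountingBackend` are hypothetical names, and plaintext floats stand in for secret shares. Each approximation is written once against a few MPC-native primitives (add, secret-by-secret multiply, multiply by a public constant), and the backend executing them can be swapped independently. The counting backend tallies secret-by-secret multiplications as a rough cost proxy, since those require communication in typical MPC protocols. The two registered variants are a limit approximation (the method CrypTen uses for exp) and a truncated Taylor series.

```python
from typing import Callable, Dict


class PlainBackend:
    """Runs MPC-native primitives on plaintext floats (stand-in for a real MPC runtime)."""

    def add(self, a: float, b: float) -> float:
        return a + b

    def mul(self, a: float, b: float) -> float:
        return a * b

    def scale(self, a: float, c: float) -> float:
        # Multiplication by a public constant: local (cheap) in typical MPC protocols.
        return a * c


class CountingBackend(PlainBackend):
    """Same arithmetic, but counts secret-by-secret multiplications."""

    def __init__(self) -> None:
        self.muls = 0

    def mul(self, a: float, b: float) -> float:
        self.muls += 1
        return a * b


# Hypothetical registry: operator name -> {variant name -> approximation}.
APPROXIMATIONS: Dict[str, Dict[str, Callable]] = {}


def register_approx(op: str, variant: str):
    """Hypothetical decorator for adding a new approximation variant."""
    def decorator(fn: Callable) -> Callable:
        APPROXIMATIONS.setdefault(op, {})[variant] = fn
        return fn
    return decorator


@register_approx("exp", "limit_8")
def exp_limit_8(be: PlainBackend, x: float) -> float:
    # exp(x) ~= (1 + x / 2^8)^(2^8), computed with 8 squarings.
    y = be.add(1.0, be.scale(x, 1.0 / 2**8))
    for _ in range(8):
        y = be.mul(y, y)
    return y


@register_approx("exp", "taylor_6")
def exp_taylor_6(be: PlainBackend, x: float) -> float:
    # exp(x) ~= sum_{k=0}^{6} x^k / k!; one secret multiplication per term from x^2 on.
    term = x
    result = be.add(1.0, x)
    for k in range(2, 7):
        term = be.scale(be.mul(term, x), 1.0 / k)
        result = be.add(result, term)
    return result


# The runtime only sees add/mul/scale, so variants can be swapped freely.
for variant, fn in APPROXIMATIONS["exp"].items():
    backend = CountingBackend()
    value = fn(backend, 1.0)
    print(f"exp/{variant}: value={value:.5f}, secret muls={backend.muls}")
```

In this sketch the limit variant costs 8 secret multiplications and the Taylor variant 5, so a tuner can trade them off without the runtime changing at all.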

Abstract

Machine learning (ML) involves private data and proprietary model parameters. MPC-based ML uses multi-party computation (MPC) to let multiple parties collaboratively run an ML workload without sharing their private data or model parameters. Because MPC cannot natively run ML operations such as Softmax or GELU, existing frameworks use different approximations. Our study shows that, on a well-optimized framework, these approximations often become the dominant bottleneck. Popular approximations are often insufficiently accurate or unnecessarily slow, and these issues are hard to identify and fix in existing frameworks. To tackle this issue, we propose CrypTorch, a compiler for MPC-based ML. CrypTorch disentangles these approximations from the rest of the MPC runtime, allows easily adding new approximations through its programming interface, and automatically selects approximations to maximize both performance and accuracy. CrypTorch is built as an extension to PyTorch 2's compiler. We show that its auto-tuning alone provides a 1.20–1.7× immediate speedup without sacrificing accuracy, and a 1.31–1.8× speedup when some accuracy degradation is allowed, compared to our well-optimized baseline. Combined with better engineering and adoption of state-of-the-art practices, the entire framework brings a 3.22–8.6× end-to-end speedup compared to the popular framework CrypTen.
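
Per-layer selection of approximations can be illustrated with a toy greedy search. This is a sketch under our own assumptions, not CrypTorch's tuner: it tunes only the iteration count of a limit approximation for e^x (`exp_limit`), uses worst-case relative error on sampled inputs (`max_rel_error`) as a stand-in for model accuracy, treats the number of squarings as the cost, and the layer names and input ranges passed to the hypothetical `tune_per_layer` are invented for illustration.

```python
import math
from typing import Dict, Iterable, Tuple


def exp_limit(x: float, iters: int) -> float:
    """Limit approximation exp(x) ~= (1 + x / 2**iters) ** (2**iters).

    Under MPC each squaring is one secret-by-secret multiplication,
    so `iters` serves as a rough cost proxy here.
    """
    y = 1.0 + x / 2**iters
    for _ in range(iters):
        y = y * y
    return y


def max_rel_error(iters: int, lo: float, hi: float, samples: int = 1000) -> float:
    """Worst-case relative error of exp_limit on evenly spaced points in [lo, hi]."""
    xs = [lo + (hi - lo) * i / (samples - 1) for i in range(samples)]
    return max(abs(exp_limit(x, iters) / math.exp(x) - 1.0) for x in xs)


def tune_per_layer(
    layer_ranges: Dict[str, Tuple[float, float]],
    tol: float,
    candidates: Iterable[int] = range(1, 17),
) -> Dict[str, int]:
    """Greedy per-layer choice: the cheapest iteration count whose error <= tol."""
    candidates = sorted(candidates)  # cheapest first
    plan = {}
    for layer, (lo, hi) in layer_ranges.items():
        plan[layer] = next(
            (n for n in candidates if max_rel_error(n, lo, hi) <= tol),
            candidates[-1],  # fall back to the most accurate candidate
        )
    return plan


# Hypothetical layers whose inputs fall in different ranges.
ranges = {"layer_a": (-1.0, 1.0), "layer_b": (-8.0, 8.0)}
plan = tune_per_layer(ranges, tol=1e-2)
print(plan)
print("total squarings:", sum(plan.values()), "vs uniform:", len(plan) * max(plan.values()))
```

With these ranges the search picks 6 iterations for `layer_a` and 12 for `layer_b`, 18 squarings in total instead of 24 for a single setting that is accurate enough for every layer. A layer with inputs near zero gets a cheaper approximation than one with a wide input range.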

Paper Structure

This paper contains 52 sections, 7 equations, and 17 figures.

Figures (17)

  • Figure 1: Scenarios for MPC-based ML.
  • Figure 2: Speedups from CrypTen++'s kernels compared to the original CrypTen's.
  • Figure 3: Overhead breakdown from CrypTen++. The thick contour shows the breakdown between higher-level operators (ReLU, Softmax, etc.). The colored patches within each thick contour further break down each higher-level operator into a set of lower-level operators (comparison, mul, etc.) supported directly by MPC. Linear, Conv2d, and MatMul operators are all shown as "Linear". While Max/MaxPool is technically a combination of comparisons and multiplications, we treat it as a separate operator (max) for simplicity.
  • Figure 4: Functional behavior (left) and the latency (right) for various MPC approximations for $e^x$.
  • Figure 5: Overview of CrypTorch. Compared to existing approaches where the approximation of an operation is implemented in the runtime library with the rest of the MPC-specific code, CrypTorch separates out the approximation into an earlier compilation stage and performs auto-tuning.
  • ...and 12 more figures
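
The kind of functional comparison Figure 4 draws for e^x can be reproduced in plaintext with a short sketch. The variants below (a limit approximation with 2 or 8 squarings and a degree-5 Taylor series) are our illustrative choices and not necessarily the ones plotted in the figure; `exp_limit` and `exp_taylor` are hypothetical helper names.

```python
import math


def exp_limit(x: float, iters: int) -> float:
    # (1 + x / 2^iters)^(2^iters) via `iters` squarings. Never negative, but
    # inaccurate once x < -2^iters, where the base 1 + x / 2^iters turns negative.
    y = 1.0 + x / 2**iters
    for _ in range(iters):
        y = y * y
    return y


def exp_taylor(x: float, degree: int) -> float:
    # Truncated Taylor series: sum_{k=0}^{degree} x^k / k!
    result, term = 1.0, 1.0
    for k in range(1, degree + 1):
        term *= x / k
        result += term
    return result


print(f"{'x':>6} {'exact':>10} {'limit(2)':>10} {'limit(8)':>10} {'taylor(5)':>10}")
for x in (-8.0, -4.0, -1.0, 0.0, 1.0, 2.0):
    print(f"{x:6.1f} {math.exp(x):10.5f} {exp_limit(x, 2):10.5f} "
          f"{exp_limit(x, 8):10.5f} {exp_taylor(x, 5):10.5f}")
```

At x = -8 the degree-5 Taylor series returns a large negative value (about -162.7), and the 2-squaring limit approximation returns exactly 1 because its base 1 + x/4 equals -1; the 8-squaring variant stays close to the true value in absolute terms at the cost of more multiplications. Each approximation is only trustworthy on some input range, which is why the choice benefits from per-layer tuning.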