MPC-Minimized Secure LLM Inference

Deevashwer Rathee, Dacheng Li, Ion Stoica, Hao Zhang, Raluca Popa

TL;DR

The paper addresses privacy-preserving LLM inference by minimizing the use of secure multi-party computation (MPC). It introduces Marill, a fine-tuning framework that splits model weights into public and private components and applies high-level architectural changes (layer freezing, LoRA adaptation, and head merging) that remove some expensive operations from MPC and relocate others outside it, with knowledge distillation used to preserve ML performance. Compared to standard fine-tuning, Marill achieves 3.6–11.3x faster runtime and 2.4–6.9x lower communication across MPC settings, while typically retaining over 90% of standard fine-tuning's performance on tasks spanning code, chat, and translation. The approach is complementary to MPC-friendly approximations and benefits from open-source pre-trained weights, enabling practical privacy-preserving LLM services across a broad range of secure inference protocols.
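To see why public weights help, consider a LoRA-adapted projection $xW$ with $W = W_{pub} + AB$, where $W_{pub}$ is the public pre-trained weight and the low-rank factors $A$ ($n \times r$) and $B$ ($r \times n$) are private. The Python sketch below is illustrative only and is not Marill's implementation. It assumes two-party additive secret sharing over a toy ring. It checks the identity $x(W_{pub} + AB) = xW_{pub} + (xA)B$, shows that each party can multiply its own share by the public $W_{pub}$ without communicating, and counts the multiply-accumulates that still involve private weights. The helpers `share`, `reconstruct`, `public_matmul_on_share`, and `private_macs` are hypothetical names introduced here.

```python
import numpy as np

MOD = 2**16  # toy ring Z_{2^16} (assumption); deployed protocols typically use larger rings
rng = np.random.default_rng(0)

def share(x):
    """Split x into two additive shares s0, s1 with x = s0 + s1 (mod MOD)."""
    s0 = rng.integers(0, MOD, size=x.shape, dtype=np.int64)
    return s0, (x - s0) % MOD

def reconstruct(s0, s1):
    return (s0 + s1) % MOD

def public_matmul_on_share(s, w_pub):
    """Each party multiplies its own share by the public weight locally; no messages exchanged."""
    return (s @ w_pub) % MOD

def private_macs(b, n, r=None):
    """Multiply-accumulates involving private weights for one (n x n) projection.
    Dense private W: b*n*n.  LoRA (x @ A) @ B with private A (n x r), B (r x n): 2*b*n*r."""
    return b * n * n if r is None else 2 * b * n * r

b, n, r = 4, 8, 2  # sequence length, hidden size, LoRA rank (illustrative values)
x = rng.integers(0, 100, size=(b, n), dtype=np.int64)       # client's activations
w_pub = rng.integers(0, 100, size=(n, n), dtype=np.int64)   # public pre-trained weight
lora_a = rng.integers(0, 10, size=(n, r), dtype=np.int64)   # private low-rank factor A
lora_b = rng.integers(0, 10, size=(r, n), dtype=np.int64)   # private low-rank factor B

# The adapted projection splits into a public-weight term and a low-rank private term.
assert np.array_equal(x @ (w_pub + lora_a @ lora_b), x @ w_pub + (x @ lora_a) @ lora_b)

# The public-weight term can be computed share-wise, without interaction.
s0, s1 = share(x)
y_pub = reconstruct(public_matmul_on_share(s0, w_pub), public_matmul_on_share(s1, w_pub))
assert np.array_equal(y_pub, (x @ w_pub) % MOD)

print(private_macs(b, n), private_macs(b, n, r))
```

With these toy sizes the private work halves (256 vs. 128 multiply-accumulates). The ratio $2r/n$ shrinks further at realistic hidden sizes, while the public term costs only local computation under this sharing model.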

Abstract

Many inference services based on large language models (LLMs) pose a privacy concern, either revealing user prompts to the service or the proprietary weights to the user. Secure inference offers a solution to this problem through secure multi-party computation (MPC); however, it is still impractical for modern LLM workloads due to the large overhead imposed by MPC. To address this overhead, we propose Marill, a framework that adapts LLM fine-tuning to minimize MPC usage during secure inference. Marill introduces high-level architectural changes during fine-tuning that significantly reduce the number of expensive operations needed within MPC during inference, by removing some and relocating others outside MPC without compromising security. As a result, Marill-generated models are more efficient across all secure inference protocols and our approach complements MPC-friendly approximations for such operations. Compared to standard fine-tuning, Marill results in 3.6-11.3x better runtime and 2.4-6.9x better communication during secure inference across various MPC settings, while typically preserving over 90% performance across downstream tasks.
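Layer freezing can be pictured as splitting the layer stack at a cut point. Layers below the cut keep their public pre-trained weights, so the client can evaluate them locally in the clear. Only the fine-tuned layers above the cut need MPC. The sketch below assumes a hypothetical toy stack of dense+ReLU layers, and `make_layer`, `split_frozen`, and `run` are illustrative names. It runs the private suffix in plaintext as a stand-in for an MPC engine; in a real deployment the intermediate activation would be secret-shared before the suffix is evaluated.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_layer(n):
    """Hypothetical toy 'layer': a random dense map followed by ReLU."""
    w = rng.standard_normal((n, n)) / np.sqrt(n)
    return lambda x: np.maximum(x @ w, 0.0)

def split_frozen(layers, num_finetuned):
    """Return (public prefix, private suffix); only the last num_finetuned layers are fine-tuned."""
    cut = len(layers) - num_finetuned
    return layers[:cut], layers[cut:]

def run(layers, x):
    for layer in layers:
        x = layer(x)
    return x

n, num_layers, num_finetuned = 16, 8, 2
layers = [make_layer(n) for _ in range(num_layers)]
public_prefix, private_suffix = split_frozen(layers, num_finetuned)

x = rng.standard_normal((4, n))  # client's private input: 4 toy token embeddings
h = run(public_prefix, x)        # client side: plaintext evaluation of frozen public layers
y = run(private_suffix, h)       # server side: would run inside MPC on secret-shared h

assert np.allclose(y, run(layers, x))
print(f"layers evaluated in MPC: {len(private_suffix)} of {num_layers}")
```

In this toy, the MPC cost depends only on the length of the fine-tuned suffix, not on the total depth. That is the intuition behind the savings that layer freezing targets.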

Paper Structure

This paper contains 23 sections, 8 figures, and 3 tables.

Figures (8)

  • Figure 1: End-to-end workflow of our system. The private and public components are highlighted in red and blue, respectively. The gray region represents our fine-tuning framework, Marill, which outputs an MPC-minimized inference model. Note that Marill differs from prior works such as MPCFormer [mpcformer], which output a (fully) fine-tuned model. Consequently, the inference phase (steps 3-5) in our system also differs from prior works in two ways: (i) only part of the inference model is private, so only that part is fed to the MPC engine, and (ii) instead of feeding its private input directly, the client inputs the partial inference result of the model's public component on its private input. The figure shows only single-token generation; subsequent tokens can be generated similarly since the client has access to all tokens generated so far. It also shows only two parties, each running an MPC engine instance; some protocols additionally use a helper party to speed up secure inference (see the appendix on MPC settings).
  • Figure 2: Marill's techniques that leverage public weights (marked in blue).
  • Figure 3: Head-merging example with $m = 2$ for sequence length $b = 3$, $h = 4$ heads, and head dimension $d = 2$. After merging, the number of heads drops to $h' = h/m = 2$ and the head dimension grows to $d' = md = 4$. The red matrices indicate that head merging is performed only in private layers.
  • Figure 4: Secure inference performance of Marill vs standard fine-tuning for openllama-3b-v2. The sequence length is set to $b=64$ for 2PC and $b=2048$ for 3PC and 2PC-Dealer. The numbers on the bars represent the improvement factor over the baseline. The final bar in each plot represents the combination of layer-freezing with head-merging or LoRA, whichever performs better independently.
  • Figure 5: Marill vs (fully) fine-tuned and zero-shot baselines.
  • ...and 3 more figures
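The head merging in Figure 3 can be sketched as reshaping the per-head query, key, and value tensors from shape $(h, b, d)$ to $(h/m, b, md)$. The Python sketch below assumes that consecutive heads are grouped and that the usual $1/\sqrt{d'}$ scaling is applied. The paper's exact grouping and scaling may differ, and `merge_heads` and `attention` are hypothetical helpers. Using the Figure 3 sizes, it checks that each merged head's raw attention scores equal the sum of its constituent heads' scores. It also checks that the number of softmax inputs falls from $h b^2 = 36$ to $h' b^2 = 18$.

```python
import numpy as np

def merge_heads(t, m):
    """(h, b, d) -> (h // m, b, m * d): concatenate each group of m consecutive heads along head-dim."""
    h, b, d = t.shape
    assert h % m == 0, "number of heads must be divisible by m"
    return t.reshape(h // m, m, b, d).transpose(0, 2, 1, 3).reshape(h // m, b, m * d)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    """Per-head scaled dot-product attention; also returns the number of softmax inputs."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])  # shape (heads, b, b)
    return softmax(scores) @ v, scores.size

b, h, d, m = 3, 4, 2, 2  # the sizes used in Figure 3
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((h, b, d)) for _ in range(3))

out, n_exp = attention(q, k, v)
qm, km, vm = merge_heads(q, m), merge_heads(k, m), merge_heads(v, m)
out_m, n_exp_m = attention(qm, km, vm)

# A merged head's raw scores are the sum of its m constituent heads' raw scores.
raw = q @ k.transpose(0, 2, 1)
assert np.allclose(qm @ km.transpose(0, 2, 1), raw.reshape(h // m, m, b, b).sum(axis=1))

print(qm.shape, n_exp, n_exp_m)
```

The $QK^\top$ multiplication does the same total work before and after merging ($h b^2 d = h' b^2 d' = 72$ here). The saving comes from running the softmax, a nonlinear operation that is expensive in MPC, over $m$ times fewer score matrices. Because a merged layer is not numerically equivalent to the original, the merged layers are fine-tuned rather than converted directly.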