TruncFormer: Private LLM Inference Using Only Truncations

Patrick Yubeaton, Jianqiao Cambridge Mo, Karthik Garimella, Nandan Kumar Jha, Brandon Reagen, Chinmay Hegde, Siddharth Garg

TL;DR

The paper addresses the impractical latency of private inference (PI) for LLMs and argues that out-of-field nonlinearities dominate that latency. It introduces TruncFormer, a framework that expresses all LLM computations as in-field additions and multiplications plus statically placed truncations, which removes the need to truncate after every operation. Using fixed-point encoding and adapted CrypTen approximations, it achieves significant latency reductions while preserving plaintext-level accuracy, as demonstrated on Llama-7B and Gemma-2B; the code is open source. Despite these gains, private inference remains slow in absolute terms. This underscores the need to further optimize truncations, the primary latency bottleneck, and to extend these ideas to other protocols.

Abstract

Private inference (PI) plays an important role in guaranteeing the privacy of user data when interfacing with proprietary machine learning models such as LLMs. However, PI remains practically intractable due to the massive latency costs associated with nonlinear functions present in LLMs. Existing works have focused on improving the latency of specific LLM nonlinearities (such as Softmax or GeLU) via approximations. However, new types of nonlinearities are regularly introduced with new LLM architectures, and this has led to a constant game of catch-up where PI researchers attempt to optimize the newest nonlinear function. We introduce TruncFormer, a framework for taking any LLM and transforming it into a plaintext emulation of PI. Our framework leverages the fact that nonlinearities in LLMs are differentiable and can be accurately approximated with a sequence of additions, multiplications, and truncations. Further, we decouple the add/multiply and truncation operations, and statically determine where truncations should be inserted based on a given field size and input representation size. This leads to latency improvements over existing cryptographic protocols that enforce truncation after every multiplication operation. We open source our code for community use.
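To make the idea of decoupling multiplication from truncation concrete, here is a minimal plaintext sketch; it is not the TruncFormer implementation. It assumes the field is modeled as the ring Z_{2^K} with K = 64 and F = 16 fractional bits, that every input satisfies |x| < 2^int_bits, and that truncation is emulated by an arithmetic right shift rather than a secure protocol. All names (`product_chain`, `rescale`, `Fixed`, and so on) are hypothetical. An eager schedule truncates after every multiplication; a lazy schedule lets fractional bits accumulate and truncates only when the worst-case bit-width of the next product would exceed K.

```python
from dataclasses import dataclass

K = 64  # bit-width of the modeled ring Z_{2^K} (hypothetical choice)
F = 16  # fractional bits of the fixed-point encoding (hypothetical choice)
MOD = 1 << K


@dataclass
class Fixed:
    value: int      # ring element in [0, 2^K)
    int_bits: int   # conservative bound: |real value| < 2^int_bits
    frac_bits: int  # current scale: represents value / 2^frac_bits


def encode(x: float) -> int:
    """Encode a real number as a ring element with F fractional bits."""
    return round(x * (1 << F)) % MOD


def to_signed(v: int) -> int:
    """Read a ring element as a two's-complement signed integer."""
    return v - MOD if v >= MOD // 2 else v


def decode(v: int, frac_bits: int) -> float:
    return to_signed(v) / (1 << frac_bits)


def truncate(v: int, bits: int) -> int:
    """Plaintext stand-in for a secure truncation: arithmetic right shift."""
    return (to_signed(v) >> bits) % MOD


def product_bits(a: Fixed, b: Fixed) -> int:
    """Worst-case width of a * b: sign bit + integer bits + fractional bits."""
    return 1 + a.int_bits + b.int_bits + a.frac_bits + b.frac_bits


def rescale(a: Fixed) -> Fixed:
    """One truncation: bring a back to F fractional bits."""
    return Fixed(truncate(a.value, a.frac_bits - F), a.int_bits, F)


def multiply(a: Fixed, b: Fixed) -> Fixed:
    """In-ring multiplication: scales add up, no truncation is applied."""
    return Fixed((a.value * b.value) % MOD,
                 a.int_bits + b.int_bits, a.frac_bits + b.frac_bits)


def product_chain(xs, int_bits: int, lazy: bool):
    """Multiply xs left to right; return (decoded result, number of truncations).

    eager: truncate after every multiplication.
    lazy:  truncate only when the next product could overflow the ring.
    Both decisions use only K, F and int_bits, never the values themselves,
    so the truncation schedule is fixed before any data is seen.
    """
    acc = Fixed(encode(xs[0]), int_bits, F)
    truncs = 0
    for x in xs[1:]:
        b = Fixed(encode(x), int_bits, F)
        if lazy and product_bits(acc, b) > K and acc.frac_bits > F:
            acc, truncs = rescale(acc), truncs + 1
        if product_bits(acc, b) > K:
            raise OverflowError("integer part no longer fits in the ring")
        acc = multiply(acc, b)
        if not lazy:
            acc, truncs = rescale(acc), truncs + 1
    if acc.frac_bits > F:  # final rescale so the output has F fractional bits
        acc, truncs = rescale(acc), truncs + 1
    return decode(acc.value, F), truncs


if __name__ == "__main__":
    xs = [1.5] * 8
    print(product_chain(xs, int_bits=1, lazy=False))  # (25.62890625, 7)
    print(product_chain(xs, int_bits=1, lazy=True))   # (25.62890625, 4)
```

For a product of eight copies of 1.5, the lazy schedule in this sketch uses 4 truncations instead of 7 and decodes to the same value. Because the check depends only on bit-widths, the schedule can be computed once, ahead of time, which loosely mirrors the abstract's idea of placing truncations statically from the field size and input representation size.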

Paper Structure

This paper contains 21 sections, 1 equation, 3 figures, 9 tables, and 2 algorithms.

Figures (3)

  • Figure 1: Newton-Raphson flowchart for approximating the inverse square root. Standard private inference protocols apply a truncation after every operation, which we show is excessive.
  • Figure 2: We compare the impact of input length and token generation count on private inference latency.
  • Figure 3: We approximate the inverse square root and the reciprocal over the range (0, 10] in steps of 0.1. The plots show the error between these approximations and the true function as computed by PyTorch, with separate lines for different numbers of Newton-Raphson iterations.
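Figures 1 and 3 concern Newton-Raphson approximations of the inverse square root and the reciprocal. The sketch below uses plain Python floats rather than fixed-point shares, together with the standard updates `y <- y * (2 - x*y)` for 1/x and `y <- y * (3 - x*y*y) / 2` for 1/sqrt(x). It shows that each iteration needs only additions and multiplications (the division by 2 is a multiplication by a public constant). The constant initial guesses y0 = 0.1 and y0 = 0.5 are assumptions chosen so that both iterations converge on (0, 10]; CrypTen and TruncFormer may use different, input-dependent guesses. The error sweep follows the setup described for Figure 3 (inputs 0.1, 0.2, ..., 10.0), with Python's `math` module standing in for PyTorch as the reference.

```python
import math


def reciprocal_nr(x: float, iters: int, y0: float = 0.1) -> float:
    """Approximate 1/x with the Newton-Raphson update y <- y * (2 - x * y).

    Only additions and multiplications are used. The iteration converges
    when 0 < x * y0 < 2, which holds for x in (0, 10] with y0 = 0.1.
    """
    y = y0
    for _ in range(iters):
        y = y * (2.0 - x * y)
    return y


def inv_sqrt_nr(x: float, iters: int, y0: float = 0.5) -> float:
    """Approximate 1/sqrt(x) with the update y <- y * (3 - x * y * y) / 2.

    Division by the public constant 2 is written as multiplication by 0.5.
    The iteration converges to the positive root when 0 < x * y0**2 < 3,
    which holds for x in (0, 10] with y0 = 0.5.
    """
    y = y0
    for _ in range(iters):
        y = 0.5 * y * (3.0 - x * y * y)
    return y


def max_abs_error(approx, exact, iters: int) -> float:
    """Worst-case absolute error over x = 0.1, 0.2, ..., 10.0."""
    xs = [0.1 * i for i in range(1, 101)]
    return max(abs(approx(x, iters) - exact(x)) for x in xs)


if __name__ == "__main__":
    for n in (2, 5, 10, 15):
        e_rec = max_abs_error(reciprocal_nr, lambda x: 1.0 / x, n)
        e_isq = max_abs_error(inv_sqrt_nr, lambda x: 1.0 / math.sqrt(x), n)
        print(f"iters={n:2d}  reciprocal={e_rec:.2e}  inv_sqrt={e_isq:.2e}")
```

With these constant guesses, the inputs nearest 0 converge slowest, so the worst-case error drops sharply only after several iterations. A better initial guess would reach the same accuracy in fewer iterations, and therefore with fewer multiplications and truncations under PI.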