Linearizing Models for Efficient yet Robust Private Inference

Sreetama Sarkar, Souvik Kundu, Peter A. Beerel

TL;DR

This work tackles the latency bottleneck of private inference by reducing ReLU non-linearities through RLNet, a robust linearized network with a shared-mask, shared-weight design. It presents a three-stage training pipeline (robust teacher training, ReLU mask search, and three-way robust distillation) coupled with data augmentation and dual Batch Normalization to maintain accuracy on clean, naturally perturbed, and adversarial inputs. Empirical results on CIFAR-10/100 and Tiny-ImageNet across ResNet and WRN variants show that RLNet uses up to 11.14× fewer ReLUs with minimal loss in clean accuracy (CA) and naturally perturbed accuracy (NPA), and with substantial gains in adversarial accuracy (AdvA) compared to non-robust linearized baselines and SeNet. The approach thus enables practical, latency-efficient private inference that stays robust across clean, naturally perturbed, and adversarial inputs, with vision transformers a potential direction for future work.

Abstract

The growing concern about data privacy has led to the development of private inference (PI) frameworks in client-server applications that protect both data privacy and model IP. However, the cryptographic primitives required yield significant latency overhead, which limits their widespread application. At the same time, changing environments demand that the PI service be robust against various naturally occurring and gradient-based perturbations. Despite several works focused on the development of latency-efficient models suitable for PI, the impact of these models on robustness has remained unexplored. Towards this goal, this paper presents RLNet, a class of robust linearized networks that can yield latency improvement via reduction of high-latency ReLU operations while improving the model performance on both clean and corrupted images. In particular, RLNet models provide a "triple win ticket" of improved classification accuracy on clean, naturally perturbed, and gradient-based perturbed images using a shared-mask shared-weight architecture with over an order of magnitude fewer ReLUs than baseline models. To demonstrate the efficacy of RLNet, we perform extensive experiments with ResNet and WRN model variants on CIFAR-10, CIFAR-100, and Tiny-ImageNet datasets. Our experimental evaluations show that RLNet can yield models with up to 11.14× fewer ReLUs, with accuracy close to the all-ReLU models, on clean, naturally perturbed, and gradient-based perturbed images. Compared with the SoTA non-robust linearized models at similar ReLU budgets, RLNet achieves an improvement in adversarial accuracy of up to ~47% and in naturally perturbed accuracy of up to ~16.4%, while improving clean image accuracy by up to ~1.5%.
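As a rough illustration of what "linearizing" a model means here, the sketch below applies ReLU only at positions selected by a binary mask and passes the remaining activations through unchanged, then reports the resulting ReLU reduction factor. This is a minimal sketch under assumptions, not the authors' implementation: the function names (`masked_relu`, `relu_reduction`), the flat-list representation, and the example mask are all hypothetical, and the mask is hand-written rather than found by the paper's mask search.

```python
from typing import List


def relu(x: float) -> float:
    """Standard ReLU on a single activation."""
    return x if x > 0.0 else 0.0


def masked_relu(activations: List[float], mask: List[int]) -> List[float]:
    """Apply ReLU where mask == 1; act as identity (linear) where mask == 0.

    Assumption: the mask is binary and has one entry per activation.
    """
    if len(activations) != len(mask):
        raise ValueError("activations and mask must have the same length")
    return [relu(a) if m == 1 else a for a, m in zip(activations, mask)]


def relu_reduction(mask: List[int]) -> float:
    """Ratio of the all-ReLU count to the number of ReLUs the mask retains."""
    kept = sum(mask)
    if kept == 0:
        raise ValueError("mask retains no ReLUs; reduction factor is undefined")
    return len(mask) / kept


# Hypothetical toy layer with four activations; only two keep their ReLU.
acts = [-2.0, 0.5, -0.25, 3.0]
mask = [1, 0, 0, 1]
out = masked_relu(acts, mask)
factor = relu_reduction(mask)
```

In a PI setting, only the positions with a mask value of 1 would incur the costly non-linear operation, which is why a higher reduction factor translates into lower latency.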

Paper Structure

This paper contains 23 sections, 5 equations, 6 figures, 6 tables.

Figures (6)

  • Figure 1: Clean, naturally perturbed, and adversarial accuracy of ResNet18 for non-robust linearized models (left) and RLNet (right) on CIFAR-10 for different ReLU budgets. SoTA linearized models lack robustness against natural and adversarial perturbations, while RLNet performs well on all three fronts, outperforming its non-robust counterpart even in clean accuracy. (Note: the axes are not all on the same scale.)
  • Figure 2: Running mean and variance of the last BN layer of ResNet18 trained on Tiny-ImageNet using dual BN
  • Figure 3: RLNet framework: model architecture and inference path. Here, $\lambda=0$ and $\lambda=1$ correspond to the clean and adversarial paths, respectively. Naturally perturbed inputs are classified using the clean path; that is, unless the perturbation is attacker-driven, we use the $\lambda=0$ path for inference.
  • Figure 4: CA, NPA, and AdvA for ResNet18 on CIFAR-10 for different training modes
  • Figure 5: (a) Dual vs. triple BN for ResNet18 on CIFAR-10; (b) CA, NPA, and AdvA of RLNet vs. separate ResNet18 PR models (ReLU count = 82k) trained using standard training, AugMix (Hendrycks et al., 2019), and PGD-AT (Madry et al., 2017).
  • ...and 1 more figure
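
The path selection described in the Figure 3 caption can be sketched as a normalization layer that keeps two sets of running statistics and picks one with $\lambda$. This is a minimal, inference-only sketch under assumptions, not the paper's code: the names `BNStats`, `DualBN`, and `select_path` are hypothetical, the layer handles a single scalar channel, and the learned affine parameters of a real Batch Normalization layer are omitted for brevity.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class BNStats:
    """Running statistics for one BN branch (hypothetical container)."""
    mean: float
    var: float


class DualBN:
    """Inference-only dual BN: shared input, branch-specific statistics.

    lam = 0 selects the clean-path statistics, lam = 1 the adversarial-path ones.
    """

    def __init__(self, clean: BNStats, adv: BNStats, eps: float = 1e-5):
        self.stats = {0: clean, 1: adv}
        self.eps = eps

    def __call__(self, x: List[float], lam: int) -> List[float]:
        if lam not in self.stats:
            raise ValueError("lam must be 0 (clean) or 1 (adversarial)")
        s = self.stats[lam]
        scale = 1.0 / math.sqrt(s.var + self.eps)
        return [(v - s.mean) * scale for v in x]


def select_path(attacker_driven: bool) -> int:
    """Clean and naturally perturbed inputs use lam = 0; attacker-driven ones use lam = 1."""
    return 1 if attacker_driven else 0


# Hypothetical statistics; eps = 0.0 keeps the toy arithmetic exact.
bn = DualBN(clean=BNStats(mean=0.0, var=1.0), adv=BNStats(mean=1.0, var=4.0), eps=0.0)
clean_out = bn([2.0], select_path(attacker_driven=False))
adv_out = bn([3.0], select_path(attacker_driven=True))
```

Because the convolutional weights and the ReLU mask are shared in RLNet, only the small per-branch normalization state differs between the two paths in a design like this.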