Coordinate Descent for Network Linearization

Vlad Rakhlin, Amir Jevnisek, Shai Avidan

TL;DR

This work tackles the discrete optimization problem of reducing ReLU activations to enable private inference (PI) with neural networks. It introduces a Block Coordinate Descent algorithm that operates directly on a binary ReLU mask, removing ReLUs iteratively and finetuning as needed to maintain accuracy. Because the mask stays binary throughout, the method yields sparse networks by design. It demonstrates state-of-the-art accuracy with ResNet18 and WideResNet22-8 on CIFAR-10, CIFAR-100, and TinyImageNet, often outperforming existing Selective approaches. Applied on top of AutoRep, it matches AutoRep's accuracy with half the ReLU budget or less. As a drop-in discrete optimization that complements existing PI pipelines, it has practical implications for latency and bandwidth efficiency in privacy-preserving inference.
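The TL;DR describes the core loop: operate on a binary ReLU mask, remove ReLUs iteratively, and finetune when accuracy drops. Below is a minimal Python sketch of such a loop. It is not the authors' implementation. The random block selection, the per-coordinate scoring, and the names `bcd_linearize`, `evaluate`, `finetune`, `drop_per_step`, and `tolerance` are all assumptions made for illustration. `drop_per_step` plays the role assumed for DRC in Figure 5, read off the formula there as the number of ReLUs removed per iteration. `tolerance` plays the role of the Accuracy Degradation Tolerance (ADT).

```python
import random
from typing import Callable, Dict, List

Mask = List[int]  # 1 = keep the ReLU, 0 = replace it with the identity


def bcd_linearize(
    mask: Mask,
    target_budget: int,
    drop_per_step: int,
    block_size: int,
    evaluate: Callable[[Mask], float],
    finetune: Callable[[Mask], None],
    tolerance: float,
    seed: int = 0,
) -> Mask:
    """Remove ReLUs block by block until at most `target_budget` remain.

    Hypothetical sketch: each step samples a block of active ReLUs, scores every
    coordinate in the block with the rest of the mask held fixed, and switches off
    the `drop_per_step` ReLUs whose removal hurts `evaluate` the least.
    """
    if target_budget < 0 or drop_per_step < 1 or block_size < 1:
        raise ValueError("invalid budget, drop_per_step or block_size")
    rng = random.Random(seed)
    mask = list(mask)
    reference = evaluate(mask)  # accuracy at the start or after the last finetune
    while sum(mask) > target_budget:
        active = [i for i, bit in enumerate(mask) if bit == 1]
        block = rng.sample(active, min(block_size, len(active)))
        # Score each coordinate of the block with all other bits held fixed.
        scores: Dict[int, float] = {}
        for i in block:
            trial = list(mask)
            trial[i] = 0
            scores[i] = evaluate(trial)
        k = min(drop_per_step, sum(mask) - target_budget, len(block))
        for i in sorted(block, key=lambda j: scores[j], reverse=True)[:k]:
            mask[i] = 0  # the mask stays binary at every step
        # Finetune only when accuracy has degraded beyond the tolerance.
        if reference - evaluate(mask) > tolerance:
            finetune(mask)
            reference = evaluate(mask)
    return mask


if __name__ == "__main__":
    # Toy stand-in for validation accuracy: each ReLU gets a made-up importance.
    weights = [0.9, 0.1, 0.5, 0.05, 0.7, 0.2, 0.3, 0.01]

    def toy_accuracy(m: Mask) -> float:
        return sum(w for w, bit in zip(weights, m) if bit) / sum(weights)

    final = bcd_linearize([1] * 8, target_budget=4, drop_per_step=2, block_size=8,
                          evaluate=toy_accuracy, finetune=lambda m: None, tolerance=0.05)
    print(final)  # [1, 0, 1, 0, 1, 0, 1, 0]
```

With `block_size` equal to the number of active ReLUs, each step reduces to a greedy sweep over all remaining coordinates. Smaller blocks trade solution quality for fewer `evaluate` calls per step.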

Abstract

ReLU activations are the main bottleneck in Private Inference (PI) based on ResNet networks, because they incur significant inference latency. Reducing the ReLU count is a discrete optimization problem, and there are two common ways to approach it. Most current state-of-the-art methods rely on a smooth approximation that jointly optimizes network accuracy and ReLU budget. However, the final hard-thresholding step of the optimization usually introduces a large performance loss. We take an alternative approach that works directly in the discrete domain, using Coordinate Descent as our optimization framework. In contrast to previous methods, this yields a sparse solution by design. Through extensive experiments, we demonstrate that our method is state of the art on common benchmarks.
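The abstract attributes the performance loss of smooth methods to their final hard-thresholding step. The following is a hypothetical illustration of that step, not the exact rule used by any particular prior method. It assumes one continuous gate value per ReLU and a keep-the-top-`budget` rule. The names `hard_threshold` and `rounding_gap` are ours. `rounding_gap` measures how far the deployed binary mask lies from the soft mask the network was trained with. A method whose mask is binary throughout has a gap of zero by construction.

```python
from typing import List


def hard_threshold(soft_mask: List[float], budget: int) -> List[int]:
    """Keep the `budget` ReLUs with the largest gate values; linearize the rest."""
    order = sorted(range(len(soft_mask)), key=lambda i: soft_mask[i], reverse=True)
    keep = set(order[:budget])
    return [1 if i in keep else 0 for i in range(len(soft_mask))]


def rounding_gap(soft_mask: List[float], hard_mask: List[int]) -> float:
    """L1 distance between the trained (soft) mask and the deployed (hard) mask."""
    return sum(abs(s - h) for s, h in zip(soft_mask, hard_mask))


if __name__ == "__main__":
    soft = [0.9, 0.55, 0.5, 0.1]  # made-up gate values after smooth optimization
    hard = hard_threshold(soft, budget=2)
    print(hard)  # [1, 1, 0, 0]
    print(round(rounding_gap(soft, hard), 2))  # 1.15
```

In this toy case, the gates at 0.55 and 0.5 are nearly tied, yet one is rounded up and the other down. The network never saw this exact configuration during training.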

Paper Structure

This paper contains 36 sections, 6 equations, 11 figures, 6 tables, 2 algorithms.

Figures (11)

  • Figure 1: Accuracy vs ReLU Budget for a ResNet18 Network: ResNet18 Accuracy [%] for classifying CIFAR-10, CIFAR-100 and TinyImageNet for different ReLU budgets [# ReLUs]. Our method achieves the best performance for every ReLU budget. For some budgets, our method even surpasses the baseline, i.e., the original network containing 100% of the ReLUs.
  • Figure 2: ReLU Reduction Methods: The green ellipse denotes methods that replace ReLUs with identity functions (known as Network Linearization). The blue ellipse denotes methods that replace ReLUs with polynomial approximations (which are faster to compute than ReLUs in a PI setting). The color of each circle denotes the optimization approach used. Selective methods (brown circles) jointly optimize the Cross-Entropy loss and a regularization term that enforces the budget. NAS methods (purple circles) use Neural Architecture Search for the optimization, while Manual (red circle) denotes a manual procedure for ReLU reduction. We use Block Coordinate Descent (star), which provides a sparse solution by design.
  • Figure 3: Ours vs SENet on a ResNet18 backbone: Our method achieves the Pareto frontier of network accuracy with respect to ReLU budget on CIFAR-100 and TinyImageNet, while staying competitive with SENet and SENet++ on CIFAR-10. We use a metric that is agnostic to the accuracy of the baseline classifier: for each method, we divide the accuracy it reaches at a given budget by the baseline accuracy, namely $\frac{\textrm{Accuracy for Budget}}{\textrm{Baseline Accuracy}}$.
  • Figure 4: Our method on top of AutoRep for CIFAR-100: (a) with a ResNet18 backbone, (b) with a WideResNet22-8 backbone. Our method can be plugged on top of any Selective method; here we demonstrate the performance gain of applying it on top of AutoRep. With ResNet18, our method reaches 72.9% accuracy with 6K ReLUs, whereas AutoRep needs more than twice as many ReLUs (15K) to reach the same accuracy. The same trend holds for WideResNet22-8, where our method reaches 76% accuracy with half the ReLU budget that AutoRep requires.
  • Figure 5: Hyperparameter ablation study: We compare the effect of various hyperparameters on the performance of the algorithm in the ResNet18/CIFAR-100 setting. Chosen parameters are circled. (a) The number of iteration steps $T = \lceil\frac{B_\textrm{ref} - B_\textrm{target}}{DRC}\rceil$ is inversely proportional to DRC. Accuracy decreases as DRC increases; conversely, accuracy increases as the number of iteration steps $T$ increases, as anticipated by Equations \ref{eq:sss_prob} and \ref{eq:sss}. In (b) and (c) we evaluate the effect of the number of finetuning epochs and of the Accuracy Degradation Tolerance (ADT) on test accuracy.
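The Figure 3 metric and the Figure 5 iteration count are simple enough to write down directly. The sketch below is hypothetical. The function names are ours, and it assumes that ReLU counts are integers and that DRC is a positive integer, so $T$ can be computed with exact integer ceiling division rather than floating-point division.

```python
def num_iterations(b_ref: int, b_target: int, drc: int) -> int:
    """T = ceil((B_ref - B_target) / DRC), following the Figure 5 caption."""
    if drc <= 0:
        raise ValueError("DRC must be positive")
    if b_target >= b_ref:
        return 0  # already within budget, nothing to remove
    return -(-(b_ref - b_target) // drc)  # exact integer ceiling division


def normalized_accuracy(acc_at_budget: float, baseline_acc: float) -> float:
    """Accuracy for Budget / Baseline Accuracy, the Figure 3 metric."""
    return acc_at_budget / baseline_acc


if __name__ == "__main__":
    # Illustrative numbers only, not values reported in the paper.
    for drc in (1_000, 3_000):
        print(drc, num_iterations(100_000, 6_000, drc))  # 94 steps, then 32 steps
    print(normalized_accuracy(70.0, 80.0))  # 0.875
```

A larger DRC means fewer, coarser steps. This matches the trend in Figure 5(a), where accuracy decreases as DRC grows.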