ReDistill: Residual Encoded Distillation for Peak Memory Reduction of CNNs
Fang Chen, Gourav Datta, Mujahid Al Rafi, Hyeran Jeon, Meng Tang
TL;DR
ReDistill addresses the high peak memory of CNN inference on memory-constrained edge devices. It pairs aggressive initial pooling in a student network with Residual Encoded Distillation (RED) blocks that align the student's down-sampled features with the teacher's features. Each RED block combines a gating mechanism with a residual encoder to transfer knowledge efficiently while keeping memory usage bounded. This yields substantial memory reductions with minimal performance loss on image classification and diffusion-based image generation. Across extensive experiments, ReDistill achieves a better memory-accuracy trade-off than existing knowledge distillation (KD) methods, with roughly four- to fivefold peak-memory reductions for classification and about fourfold reductions for DDPM-based image generation, which makes it a practical candidate for deployment on edge hardware. The work provides a versatile, memory-centric distillation framework that can be combined with existing KD strategies and quantization techniques, supporting scalable edge deployment and opening the way to future extensions to other architectures such as vision transformers.
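A minimal sketch of the general idea of aligning a student's down-sampled features with a teacher's higher-resolution features is shown below. It is a generic feature-based distillation loss, not the RED block: it omits the gating mechanism and residual encoder, assumes both feature maps have the same number of channels and an integer resolution ratio, and simply average-pools the teacher map to the student's resolution before computing the mean squared error. The names `avg_pool` and `feature_distill_loss` are hypothetical and used only for illustration.

```python
from typing import List

# A feature map stored as [channels][height][width].
Feature = List[List[List[float]]]


def avg_pool(x: Feature, k: int) -> Feature:
    """Non-overlapping k x k average pooling applied to every channel."""
    pooled = []
    for ch in x:
        h, w = len(ch) // k, len(ch[0]) // k
        pooled.append([
            [sum(ch[i * k + di][j * k + dj] for di in range(k) for dj in range(k)) / (k * k)
             for j in range(w)]
            for i in range(h)
        ])
    return pooled


def feature_distill_loss(student: Feature, teacher: Feature) -> float:
    """MSE between student features and teacher features pooled to the same size.

    Assumes equal channel counts and an integer ratio between the teacher's
    and the student's spatial resolution (hypothetical simplification).
    """
    k = len(teacher[0]) // len(student[0])  # resolution ratio, e.g. 2 or 4
    target = avg_pool(teacher, k)
    total, count = 0.0, 0
    for s_ch, t_ch in zip(student, target):
        for s_row, t_row in zip(s_ch, t_ch):
            for s_val, t_val in zip(s_row, t_row):
                total += (s_val - t_val) ** 2
                count += 1
    return total / count


# Toy example: a 1-channel 4x4 teacher map and a 2x2 student map.
teacher = [[[float(4 * i + j) for j in range(4)] for i in range(4)]]
student = [[[2.5, 4.5], [10.5, 12.5]]]
print(feature_distill_loss(student, teacher))  # → 0.0 (student matches pooled teacher)
```

In practice such a loss would be added to the student's task loss during training, and a learned projection would be needed whenever channel counts differ. ReDistill's RED blocks instead perform this alignment with a gating mechanism and a residual encoder.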
Abstract
Growing neural network sizes and higher-resolution image sensors increase the memory and power needed to run modern computer vision models. To deploy these models on extremely resource-constrained edge devices, it is crucial to reduce their peak memory, i.e., the maximum memory consumed during the execution of a model. A naive way to reduce peak memory is to aggressively down-sample feature maps via pooling with a large stride, but this often degrades network performance unacceptably. To mitigate this problem, we propose residual encoded distillation (ReDistill) for peak memory reduction in a teacher-student framework, in which a student network with lower memory is derived from the teacher network through aggressive pooling. We apply our distillation method to multiple computer vision problems, including image classification and diffusion-based image generation. For image classification, our method yields a 4x-5x theoretical peak memory reduction with less accuracy degradation for most CNN-based architectures. For diffusion-based image generation, it yields a denoising network with 4x lower theoretical peak memory while maintaining decent diversity and fidelity of the generated images. Experiments show that our method outperforms other feature-based and response-based distillation methods when applied to the same student network. The code is available at https://github.com/mengtang-lab/ReDistill.
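The notion of theoretical peak memory, and why early aggressive down-sampling shrinks it, can be illustrated with a small calculation. The sketch below is a hypothetical model, not the paper's measurement procedure: it assumes float32 activations, treats peak memory as the largest sum of one layer's input and output activations in a purely sequential CNN (no skip connections), and uses made-up CIFAR-sized layer shapes. The names `Layer`, `feature_shapes`, and `peak_memory_bytes` are illustrative only.

```python
from dataclasses import dataclass
from typing import List, Tuple

BYTES_PER_VALUE = 4  # assumption: float32 activations


@dataclass
class Layer:
    name: str
    out_channels: int
    stride: int  # spatial down-sampling factor (conv stride or pooling)


def feature_shapes(in_shape: Tuple[int, int, int], layers: List[Layer]) -> List[Tuple[int, int, int]]:
    """(C, H, W) of the network input followed by every layer output."""
    c, h, w = in_shape
    shapes = [(c, h, w)]
    for layer in layers:
        h, w = h // layer.stride, w // layer.stride
        c = layer.out_channels
        shapes.append((c, h, w))
    return shapes


def peak_memory_bytes(in_shape: Tuple[int, int, int], layers: List[Layer]) -> int:
    """Largest input + output activation footprint over all layers.

    Assumes a purely sequential network in which only the current layer's
    input and output tensors are alive at once (no skip connections).
    """
    sizes = [c * h * w * BYTES_PER_VALUE for c, h, w in feature_shapes(in_shape, layers)]
    return max(sizes[i] + sizes[i + 1] for i in range(len(sizes) - 1))


# Hypothetical networks; the student's stem down-samples by an extra factor
# of 2, standing in for the aggressive initial pooling.
teacher = [Layer("stem", 64, 1), Layer("stage1", 64, 1),
           Layer("stage2", 128, 2), Layer("stage3", 256, 2)]
student = [Layer("stem", 64, 2), Layer("stage1", 64, 1),
           Layer("stage2", 128, 2), Layer("stage3", 256, 2)]

t = peak_memory_bytes((3, 32, 32), teacher)
s = peak_memory_bytes((3, 32, 32), student)
print(f"teacher {t} B, student {s} B, reduction {t / s:.1f}x")
# → teacher 524288 B, student 131072 B, reduction 4.0x
```

Halving both the height and width of every feature map after the stem cuts each map's area by 4, so the peak drops by 4x in this toy setup. The input image tensor is unaffected, so when the peak occurs at the very first layer the reduction is smaller, and skip connections that keep extra tensors alive would change the accounting.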
