Learnable Sparsity for Vision Generative Models
Yang Zhang, Er Jin, Wenzhong Liang, Yanfei Dong, Ashkan Khakzar, Philip Torr, Johannes Stegmaier, Kenji Kawaguchi
TL;DR
Large vision generative models are limited in practice by their size and compute cost. EcoDiff introduces a differentiable, end-to-end pruning framework that uses time-step gradient checkpointing to learn a single mask shared across all denoising steps, pruning up to 20% of parameters without retraining. Structural masking targets attention heads and FFN neurons, using a hard-concrete relaxation with L0 regularization so that the resulting sparsity is deployable. The approach preserves semantic fidelity on SDXL and FLUX, can optionally be followed by lightweight post-pruning retraining (LoRA or full fine-tuning), and remains compatible with step-distilled variants, reducing resource demands for practical deployment.
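As a rough illustration of the hard-concrete relaxation and L0 penalty named above, the sketch below uses the standard hard-concrete parameterization in plain Python. The constants `GAMMA`, `ZETA`, and `BETA`, the function names, and the per-unit `log_alpha` parameter are assumptions made for illustration, not values or code taken from EcoDiff.

```python
import math
import random

# Assumed stretch interval (GAMMA, ZETA) and temperature BETA; commonly used
# defaults for hard-concrete gates, not values reported for EcoDiff.
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gate(log_alpha, rng):
    """Draw one relaxed gate value in [0, 1] for a prunable unit."""
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)  # keep the logs finite
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA            # stretch to (GAMMA, ZETA)
    return min(1.0, max(0.0, s_bar))              # hard clip to [0, 1]


def deterministic_gate(log_alpha):
    """Noise-free gate, as one might use after optimization."""
    return min(1.0, max(0.0, sigmoid(log_alpha) * (ZETA - GAMMA) + GAMMA))


def expected_l0(log_alphas):
    """Differentiable surrogate for the number of non-zero gates."""
    shift = BETA * math.log(-GAMMA / ZETA)
    return sum(sigmoid(la - shift) for la in log_alphas)


rng = random.Random(0)
log_alphas = [10.0, -10.0]  # one clearly kept unit, one clearly pruned unit
gates = [sample_gate(la, rng) for la in log_alphas]
penalty = expected_l0(log_alphas)
```

In a structural pruning setup, one such gate would multiply the output of each attention head or FFN neuron, and `expected_l0` would be added to the training loss with a weighting coefficient. Units whose deterministic gate is exactly zero can then be removed from the model.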
Abstract
Diffusion models have achieved impressive advancements in various vision tasks. However, these gains often rely on increasing model size, which escalates computational complexity and memory demands, complicating deployment, raising inference costs, and increasing environmental impact. While some studies have explored pruning techniques to improve the memory efficiency of diffusion models, most existing methods require extensive retraining to retain model performance. Retraining a modern large diffusion model is extremely costly and resource-intensive, which limits the practicality of these methods. In this work, we achieve low-cost diffusion pruning without retraining by proposing a model-agnostic structural pruning framework for diffusion models that learns a differentiable mask to sparsify the model. To ensure effective pruning that preserves the quality of the final denoised latent, we design a novel end-to-end pruning objective that spans the entire diffusion process. As end-to-end pruning is memory-intensive, we further propose time-step gradient checkpointing, a technique that significantly reduces memory usage during optimization, enabling end-to-end pruning within a limited memory budget. Results on the state-of-the-art U-Net diffusion model SDXL and the diffusion transformer FLUX demonstrate that our method can effectively prune up to 20% of the parameters with minimal perceptible performance degradation and, notably, without the need for model retraining. We also show that our method can be applied on top of time-step-distilled diffusion models.
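The abstract does not spell out how time-step gradient checkpointing works, so the following is only a toy sketch of the general idea under stated assumptions: during the forward pass, keep just the input of each denoising step, then recompute each step's local derivatives during the backward pass while accumulating the gradient of a final-latent loss with respect to a mask shared by all steps. The scalar `step` function, the squared-error loss, and all names are hypothetical stand-ins for a real denoiser and objective.

```python
import math


def step(x, m, eta=0.1):
    """One toy 'denoising' step; m is a mask value shared by every step."""
    return x - eta * m * math.tanh(x)


def step_grads(x, m, eta=0.1):
    """Recompute the step's local derivatives from its stored input x."""
    t = math.tanh(x)
    dx = 1.0 - eta * m * (1.0 - t * t)  # d step / d x
    dm = -eta * t                       # d step / d m
    return dx, dm


def loss_and_mask_grad(x0, m, target, num_steps):
    # Forward pass: store only the input of each step (the checkpoints),
    # not the intermediate values computed inside the step.
    checkpoints = []
    x = x0
    for _ in range(num_steps):
        checkpoints.append(x)
        x = step(x, m)
    loss = 0.5 * (x - target) ** 2

    # Backward pass: walk the steps in reverse, recomputing each step's
    # internals from its checkpoint and accumulating the mask gradient.
    g = x - target  # d loss / d x_T
    grad_m = 0.0
    for x_in in reversed(checkpoints):
        dx, dm = step_grads(x_in, m)
        grad_m += g * dm  # contribution of this step to d loss / d m
        g *= dx           # propagate to the previous step's output
    return loss, grad_m


loss, grad_m = loss_and_mask_grad(x0=1.5, m=0.8, target=0.0, num_steps=20)
```

In a real network each step holds many intermediate activations, so storing one latent per step and recomputing the rest trades extra compute for a much smaller peak memory footprint. The scalar example only shows the bookkeeping.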
