Fast Ensembling with Diffusion Schrödinger Bridge

Hyunsu Kim, Jongmin Yoon, Juho Lee

TL;DR

The paper tackles the high computational cost of Deep Ensembles by introducing Diffusion Bridge Networks (DBN) that learn a Schrödinger-bridge-based diffusion path between a single source model's logits and the ensemble's logits. By training a lightweight score network and distilling diffusion steps, DBN enables ensemble-like predictions with a fraction of the cost, while preserving accuracy and improving calibration. Empirical results on CIFAR-10/100 and TinyImageNet show DBN achieving competitive accuracy and uncertainty metrics with substantially reduced FLOPs and parameter counts compared to full ensembles and existing fast-ensemble methods. The approach is extensible via multiple diffusion bridges to scale ensemble capacity, though per-bridge limitations and training cost for multiple bridges remain considerations for deployment.

Abstract

The Deep Ensemble (DE) approach is a straightforward technique for enhancing the performance of deep neural networks: several networks are trained from different initial points and converge towards various local optima. However, a limitation of this methodology is its high computational overhead at inference, arising from the need to store numerous learned parameters and to execute a separate forward pass for each set of parameters. We propose a novel approach called Diffusion Bridge Network (DBN) to address this challenge. Based on the theory of the Schrödinger bridge, this method directly learns to simulate a Stochastic Differential Equation (SDE) that connects the output distribution of a single ensemble member to the output distribution of the ensembled model, allowing us to obtain ensemble predictions without having to invoke a forward pass through all the ensemble models. By substituting the heavy ensembles with the lightweight neural network that constitutes DBN, we achieve inference with reduced computational cost while maintaining accuracy and uncertainty scores on benchmark datasets such as CIFAR-10, CIFAR-100, and TinyImageNet. Our implementation is available at https://github.com/kim-hyunsu/dbn.
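
To make the idea in the abstract concrete, the following is a minimal toy sketch, not the authors' implementation. It shows (a) a Deep Ensemble prediction as the average of the members' softmax outputs, which needs one forward pass per member, and (b) an Euler-Maruyama simulation of an SDE that carries one member's logits to an ensemble-level target. The function names, the example logits, the choice of the log of the averaged probabilities as the target, and the Brownian-bridge drift are all assumptions made for illustration. In particular, the drift here uses the known target, whereas DBN would replace it with a learned network that never sees the ensemble at inference time.

```python
import math
import random

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_probs(member_logits):
    """Deep Ensemble prediction: average the members' softmax outputs."""
    probs = [softmax(z) for z in member_logits]
    n = len(probs)
    return [sum(p[k] for p in probs) / n for k in range(len(probs[0]))]

def toy_bridge(source, target, steps=50, sigma=0.5, seed=0):
    """Euler-Maruyama simulation of a Brownian bridge from source to target.

    ASSUMPTION: the drift (target - x) / (1 - t) uses the known target.
    It is a stand-in for the learned drift in a bridge-based method.
    """
    rng = random.Random(seed)
    dt = 1.0 / steps
    x = list(source)
    for i in range(steps):
        t = i * dt
        last = (i == steps - 1)  # no noise on the final step, so the path ends at the target
        for k in range(len(x)):
            drift = (target[k] - x[k]) / (1.0 - t)
            noise = 0.0 if last else sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0)
            x[k] += drift * dt + noise
    return x

# Hypothetical logits from three ensemble members for one input (3 classes).
members = [[2.0, 0.5, -1.0], [1.0, 1.5, -0.5], [0.5, 0.2, 1.0]]
target = [math.log(p) for p in ensemble_probs(members)]  # assumed target logits
out = toy_bridge(members[0], target)
```

In this toy setting the bridge ends at the target logits by construction. The point of the learned version described in the abstract is to reach a comparable endpoint from a single member's output without running the other members.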

Paper Structure (43 sections, 13 equations, 6 figures, 7 tables, 1 algorithm)

Figures (6)

  • Figure 1: Overview of DBN. For a given data point, the conditional diffusion bridge learns a transition between the logit distribution of one of the ensemble members (left; source) and that of the target ensemble models (right; target).
  • Figure 2: Confidences from the source model (first column), from the ensemble model (third column), and from the diffusion bridge (middle column) in the CIFAR-10 dataset. The middle column illustrates a transition of the diffusion process.
  • Figure 3: The number of teachers that a single model can distill in terms of ACC (left) and DEE (right). DEEs less than 1 are set to 0.5.
  • Figure 4: From left (1st) to right (4th): number of target ensembles vs. accuracy (1st) or DEE (2nd), and relative FLOPs vs. accuracy (3rd) or DEE (4th) on CIFAR-10.
  • Figure 5: Score network architecture.
  • ...and 1 more figure