Variational Distillation of Diffusion Policies into Mixture of Experts

Hongyi Zhou, Denis Blessing, Ge Li, Onur Celik, Xiaogang Jia, Gerhard Neumann, Rudolf Lioutikov

TL;DR

Variational Diffusion Distillation (VDD) is a novel method that distills denoising diffusion policies into Mixtures of Experts (MoE) through variational inference. It leverages a decompositional upper bound of the variational objective that allows each expert to be trained separately, resulting in a robust optimization scheme for MoEs.

Abstract

This work introduces Variational Diffusion Distillation (VDD), a novel method that distills denoising diffusion policies into Mixtures of Experts (MoE) through variational inference. Diffusion Models are the current state-of-the-art in generative modeling due to their exceptional ability to accurately learn and represent complex, multi-modal distributions. This ability allows Diffusion Models to replicate the inherent diversity in human behavior, making them the preferred models in behavior learning such as Learning from Human Demonstrations (LfD). However, diffusion models come with some drawbacks, including the intractability of likelihoods and long inference times due to their iterative sampling process. The inference times, in particular, pose a significant challenge to real-time applications such as robot control. In contrast, MoEs address both of these issues while retaining the ability to represent complex distributions, but they are notoriously difficult to train. VDD is the first method that distills pre-trained diffusion models into MoE models, and hence combines the expressiveness of Diffusion Models with the benefits of Mixture Models. Specifically, VDD leverages a decompositional upper bound of the variational objective that allows each expert to be trained separately, resulting in a robust optimization scheme for MoEs. Across nine complex behavior learning tasks, VDD is able to: i) accurately distill complex distributions learned by the diffusion model, ii) outperform existing state-of-the-art distillation methods, and iii) surpass conventional methods for training MoEs.
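
To make the "decompositional upper bound" more concrete, the following is a sketch of one standard way such a bound can arise for a mixture model; it is not taken from the paper, and the paper's exact objective may differ. The symbols are assumptions for illustration: $\pi(\mathbf{a}|\mathbf{s})$ denotes the diffusion teacher, $q(\mathbf{a}|\mathbf{s}) = \sum_z q(z|\mathbf{s})\, q(\mathbf{a}|\mathbf{s},z)$ the MoE student, and $\tilde{q}(z|\mathbf{a},\mathbf{s})$ an auxiliary distribution over experts.

```latex
\begin{align}
\mathrm{KL}\big(q(\mathbf{a}|\mathbf{s}) \,\|\, \pi(\mathbf{a}|\mathbf{s})\big)
  &= U(\mathbf{s}) - \mathbb{E}_{q(\mathbf{a}|\mathbf{s})}
     \Big[\mathrm{KL}\big(q(z|\mathbf{a},\mathbf{s}) \,\|\, \tilde{q}(z|\mathbf{a},\mathbf{s})\big)\Big]
   \;\le\; U(\mathbf{s}), \\
U(\mathbf{s})
  &= \sum_{z} q(z|\mathbf{s}) \Big( \mathbb{E}_{q(\mathbf{a}|\mathbf{s},z)}
     \big[\log q(\mathbf{a}|\mathbf{s},z) - \log \pi(\mathbf{a}|\mathbf{s})
     - \log \tilde{q}(z|\mathbf{a},\mathbf{s})\big] + \log q(z|\mathbf{s}) \Big).
\end{align}
```

The identity follows from $\log q(\mathbf{a}|\mathbf{s}) = \log q(\mathbf{a}|\mathbf{s},z) + \log q(z|\mathbf{s}) - \log q(z|\mathbf{a},\mathbf{s})$ and the non-negativity of the KL divergence. For a fixed $\tilde{q}$, the inner expectation for expert $z$ involves only that expert's own parameters, which is what would allow the experts to be updated separately under this sketch.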

Paper Structure

This paper contains 40 sections, 37 equations, 7 figures, 7 tables, 1 algorithm.

Figures (7)

  • Figure 1: VDD distills a diffusion policy into an MoE. LfD is challenging due to the multimodality of human behavior. For example, tele-operated demonstrations of an avoiding task often contain multiple solutions [jia2024towards]. Lower: A diffusion policy can predict high-quality actions but relies on an iterative sampling process from noise to data, shown as the red arrows. Upper: VDD uses the score function to distill a diffusion policy into an MoE, unifying the advantages of both approaches.
  • Figure 2: Illustration of training VDD using the score function for a fixed state in a 2D toy task. (a) The probability density of the distribution is depicted by the color map. The score function is shown by the gradient field, visualized as white arrows. From (b) to (f), we initialize and train VDD until convergence. We initialize 8 components, each represented by an orange circle. These components are driven by the score function to match the data distribution and avoid overlapping modes by utilizing the learning objective in Eq. (\ref{eq:cmp_update_obj}). Eventually, they align with all data modes.
  • Figure 3: Ablation studies for key design choices used in VDD. (a) Using only one expert leads to a higher success rate but cannot solve the task in diverse ways. With sufficiently many experts, the policy can trade off task success against action diversity. (b) Learning the gating distribution improves the success rates in three D3IL tasks. (c) A uniform gating leads to higher task entropy in two out of three tasks. (d) Sampling the score from multiple noise levels leads to better distillation performance.
  • Figure 4: Trajectory visualization for VDD with different numbers of components $Z \in \{1,2,4,8\}$ on the Avoiding task (left). Different colors indicate the components with the highest likelihood according to the learned gating network $q^{\xi}(z|\mathbf{s})$ at a state $\mathbf{s}$. For each step, we select the action by first sampling an expert from the categorical gating distribution and then taking the mean of the expert prediction. We decompose the case $Z=8$ and visualize the individual experts $z_i$ (bottom row). Diverse behavior emerges as multiple actions are likely given the same state, for example, moving to the bottom right ($z_1$) and top right ($z_2$). An extreme case of losing diversity is seen with $Z=1$, where the policy is unable to capture the diverse behavior of the diffusion teacher, leading to deterministic trajectories.
  • Figure 5: VDD training
  • ...and 2 more figures
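
The action-selection rule described in the Figure 4 caption (sample an expert from the categorical gating distribution, then take the mean of that expert's prediction) can be sketched as follows. This is a minimal illustration, not the authors' implementation: `select_action`, `toy_gating_logits`, and `toy_expert_means` are hypothetical names, and the two toy experts simply mimic the "bottom right" and "top right" behaviors mentioned in the caption.

```python
import numpy as np


def select_action(state, gating_logits_fn, expert_mean_fns, rng):
    """Sample an expert index from the gating, then return that expert's mean action.

    gating_logits_fn and expert_mean_fns are hypothetical stand-ins for the
    learned gating network and the expert networks.
    """
    logits = np.asarray(gating_logits_fn(state), dtype=float)
    # Numerically stable softmax to obtain categorical probabilities over experts.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    # Draw one expert index z from the categorical gating distribution.
    z = int(rng.choice(len(probs), p=probs))
    # Use the mean prediction of the selected expert as the action (no action noise).
    action = np.asarray(expert_mean_fns[z](state), dtype=float)
    return z, action


# Hypothetical toy setup with two experts and a uniform gating (assumption).
def toy_gating_logits(state):
    return np.zeros(2)


toy_expert_means = [
    lambda s: s + np.array([1.0, -1.0]),  # expert 0: move to the bottom right
    lambda s: s + np.array([1.0, 1.0]),   # expert 1: move to the top right
]

rng = np.random.default_rng(0)
state = np.zeros(2)
z, action = select_action(state, toy_gating_logits, toy_expert_means, rng)
```

Because the only randomness is the choice of expert, a policy with a single expert becomes deterministic, which matches the loss of diversity the caption reports for $Z=1$. A single forward pass through the gating and one expert also replaces the iterative denoising loop of the diffusion teacher.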