RAD: Redundancy-Aware Distillation for Hybrid Models via Self-Speculative Decoding

Yuichiro Hoshino, Hideyuki Tachibana, Muneyoshi Inahara, Hiroto Takegawa

TL;DR

RAD tackles inefficiency in hybrid Transformer-SSM models by identifying redundant attention layers through self-speculative decoding and replacing them with efficient SSM blocks. A Bayesian-optimized, redundancy-aware distillation pipeline then trains only the added SSM components, enabling a Born-Again-like improvement on reasoning tasks and faster convergence in standard distillation. Empirical results on Llama3.2-3B-Instruct show significant gains on GSM8K and CRUX with an 8B teacher and strong performance across general and long-context benchmarks, illustrating the practical viability of selective architectural refinement. By combining targeted layer replacement with focused distillation, RAD offers a scalable path to efficient, high-performance hybrid models for challenging tasks.

Abstract

Hybrid models combining Transformers and State Space Models (SSMs) are promising for balancing performance and efficiency. However, optimizing these hybrid models, particularly by addressing the potential redundancy inherent within the Transformer components, remains a significant challenge. In this paper, we propose RAD (Redundancy-Aware Distillation), a novel framework that uses self-speculative decoding as a diagnostic tool to identify redundant attention layers within the model. These identified layers are then selectively replaced with SSM components, followed by targeted (self-)distillation. Specifically, RAD focuses knowledge transfer on the components identified as redundant, considering architectural changes and specific weight initialization strategies. We experimentally demonstrate that self-distillation using RAD significantly surpasses the performance of the original base model on mathematical and coding tasks. Furthermore, RAD is also effective in standard knowledge distillation settings, achieving up to approximately 2x faster convergence compared to baseline methods. Notably, while a baseline model distilled from a Llama-3.1 70B teacher achieves scores of 46.17 on GSM8K and 22.75 on CRUX, RAD achieves significantly higher scores of 71.27 on GSM8K and 28.25 on CRUX, even when using a much smaller Llama-3.1 8B teacher. RAD offers a new pathway for efficient optimization and performance enhancement in the distillation of hybrid models.

Paper Structure

This paper contains 51 sections, 22 equations, 9 figures, 9 tables, 1 algorithm.

Figures (9)

  • Figure 1: Overview of our proposed RAD (Redundancy-Aware Distillation) framework. (a) Redundancy Identification: Redundant attention layers are identified via self-speculative decoding (selectively skipping attention layers) and Bayesian optimization aimed at maximizing the resulting inference throughput. (b) Hybrid Model Initialization: The identified redundant attention layers are replaced with SSM blocks (e.g., Mamba2) using specific weight initialization strategies (copying the 'out_proj' weights from the attention block and zero-initializing the 'in_proj' weights) to create the initial hybrid model $\mathcal{M}_{\textit{hyb}}$. (c) Redundancy-Aware Distillation: Knowledge is distilled from the teacher model $\mathcal{M}_{p}$ to the student hybrid model $\mathcal{M}_{\textit{hyb}}$ by training only the parameters of the newly added SSM blocks, using forward KL divergence on the final output logits.
  • Figure 2: Comparison of training loss curves with the baseline proposed by Wang et al. (2024) [junxiongdaniele2024mambainllama] for 50% Mamba2 distillation.
  • Figure 3: Comparison of training loss curves in the self-distillation setting. opt: high-throughput attention layers; worse: low-throughput attention layers.
  • Figure 4: Comparison of self-distillation loss (50% of layers replaced with Mamba2) with and without zero-initialization of 'in_proj' in the SSM block. Optimal, worse, and equal-interval layer selections are also shown.
  • Figure 5: Comparison of training loss curves in the self-distillation setting, with 8 layers replaced in all models.
  • ...and 4 more figures
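
The distillation step described in the Figure 1 caption (forward KL divergence on the final output logits, with only the newly added SSM blocks trained) can be sketched roughly as follows. This is a minimal illustration in plain Python, not the authors' implementation: the function names, the `layers.{i}.ssm.` parameter naming scheme, and the example layer indices are all assumptions made for the example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def forward_kl(teacher_logits, student_logits):
    """Forward KL divergence KL(p_teacher || p_student) for one token position."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def trainable_names(param_names, replaced_layers):
    """Keep only parameters of the newly added SSM blocks.

    Assumes a hypothetical naming scheme 'layers.<index>.ssm.<param>';
    every other parameter would stay frozen during distillation.
    """
    prefixes = tuple(f"layers.{i}.ssm." for i in sorted(replaced_layers))
    return [name for name in param_names if name.startswith(prefixes)]


# Hypothetical parameter names for a tiny hybrid model in which
# layers 1 and 3 had their attention block replaced by an SSM block.
names = [
    "layers.0.attn.q_proj",
    "layers.1.ssm.in_proj",
    "layers.1.ssm.out_proj",
    "layers.2.attn.q_proj",
    "layers.3.ssm.in_proj",
    "layers.3.mlp.up_proj",
]
selected = trainable_names(names, {1, 3})
loss = forward_kl([0.0, 0.0], [0.0, math.log(3.0)])
```

In a real training loop the per-position loss would be averaged over the sequence and batch, and the selection would be applied by disabling gradients for every parameter outside the returned list.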