ElastiFormer: Learned Redundancy Reduction in Transformer via Self-Distillation

Junzhang Liu, Tingkai Liu, Yueyuan Sui, Stephen Xia

TL;DR

ElastiFormer is a post-training technique that adapts pretrained Transformer models into elastic counterparts with variable inference-time compute; the authors also show that ElastiFormer is robust to the choice of training domain.

Abstract

We introduce ElastiFormer, a post-training technique that adapts pretrained Transformer models into elastic counterparts with variable inference-time compute. ElastiFormer introduces small routing modules (as few as 0.00006% additional trainable parameters) that dynamically select, in an input-dependent manner, the subsets of network parameters and input tokens processed by each layer of the pretrained network. The routing modules are trained with self-distillation losses that minimize the difference between the outputs of the pretrained model and its elastic counterpart. Because ElastiFormer makes no assumptions about the modality of the pretrained Transformer, it can be readily applied across modalities, covering causal language modeling, image modeling, and visual-language modeling tasks. We show that compute savings of 20% to 50% can be achieved for different components of the Transformer layer, and that compute can be reduced further by adding very low-rank (rank-1) LoRA weights trained with the same distillation objective. Finally, by comparing routing trained on different subsets of ImageNet, we show that ElastiFormer is robust to the choice of training domain.
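The two mechanisms in the abstract, input-dependent token routing and self-distillation, can be illustrated with a small pure-Python sketch. It is not the authors' implementation: the linear router weights (`router_w`), the top-k capacity rule, the sigmoid gate, the residual pass-through for skipped tokens, and the choice of forward KL(teacher || student) as the loss are all illustrative assumptions. What it does reflect from the paper is that a learned router chooses which tokens a layer processes, and that the routing weight multiplies the layer output so the router can receive gradient.

```python
import math
from typing import Callable, List

Vector = List[float]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def routed_layer(
    tokens: List[Vector],
    layer: Callable[[Vector], Vector],
    router_w: Vector,
    capacity: float,
) -> List[Vector]:
    """Apply `layer` only to the top `capacity` fraction of tokens.

    Hypothetical design (assumption): a linear router scores each token,
    the top-k tokens are processed and their layer output is scaled by the
    sigmoid routing weight, and all other tokens skip the layer through the
    residual connection unchanged.
    """
    scores = [sum(w * x for w, x in zip(router_w, t)) for t in tokens]
    k = max(1, math.ceil(capacity * len(tokens)))
    ranked = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)
    selected = set(ranked[:k])
    out = []
    for i, t in enumerate(tokens):
        if i in selected:
            gate = sigmoid(scores[i])  # routing weight multiplies the output
            out.append([x + gate * y for x, y in zip(t, layer(t))])
        else:
            out.append(list(t))  # skipped token: residual identity only
    return out


def kl_teacher_student(teacher_logits: Vector, student_logits: Vector) -> float:
    """KL(p_teacher || p_student): one possible self-distillation loss."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
```

With `capacity=0.5`, `routed_layer` runs the layer on half of the tokens and leaves the rest untouched. During training, the elastic model's output logits would be compared against those of the frozen pretrained model with a loss such as `kl_teacher_student`, so only the router parameters need to be learned.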

Paper Structure

This paper contains 35 sections, 5 equations, 12 figures, 1 table, 2 algorithms.

Figures (12)

  • Figure 1: Overview of ElastiFormer for language, visual, and multi-modal transformers. (Left) Illustration of learned routing modules around the Multi-Head Attention (MHA) and Multi-layer Perceptron (MLP) modules of a pretrained transformer model. (Middle) Illustration of learned routing modules inside the MLP and MHA modules, and of learned routing that selects the subset of image tokens passed as multi-modal input to the language decoder in VLMs. (Right) Illustration of the self-distillation training objectives across modalities. Note that for Vision Transformers (ViT), the example shown here is from a Masked Autoencoder (MAE) ViT.
  • Figure 2: Difference in language modeling loss (blue) and top-1 token prediction agreement (red) between the pretrained Gemma2 model and a Gemma2 model with skipped MLP layers (left) or skipped attention heads in MHA (right). Experiments are performed on the GSM8K (solid line) and HumanEval (dashed line) datasets.
  • Figure 3: Illustration of the two subset selection schemes employed in the current work. For parameter subset selection, the sub-modules (Mod 1-4) that are selected by the routing scheme can either refer to the attention heads in MHA or experts in MLP. Note that to create experts in MLP from a pretrained dense MLP layer, we first transform the dense MLP parameters to block matrices that form the experts. In either routing scheme, the routing weights are multiplied with the output to ensure gradient flow.
  • Figure 4: Comparison between different distillation losses for the language output modality. The three variants of the KL-divergence objective are illustrated in the bottom row.
  • Figure 5: Scaling of various elastic modules in Elasti-LLM against compute for Phi-3.5-mini-instruct. The LM loss of the pretrained teacher model is shown as a horizontal black line in each subfigure.
  • ...and 7 more figures
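The Figure 3 caption notes that MLP experts are created by transforming the dense MLP parameters into block matrices. One simple way to do this, assumed here purely for illustration, is to split the hidden units into contiguous, equal-sized blocks: expert e keeps the matching rows of the input projection and the matching columns of the output projection. With an elementwise activation (ReLU in this sketch; a gated MLP would need the same split applied to its gate projection as well), summing all experts reproduces the dense output exactly, so routing to only the top-k experts skips compute in a controlled way. The helper names and the top-k selection rule are hypothetical, not the paper's code.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]
Expert = Tuple[Matrix, Matrix]  # (input-projection block, output-projection block)


def relu(x: float) -> float:
    return max(0.0, x)


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def dense_mlp(w_in: Matrix, w_out: Matrix, x: Vector) -> Vector:
    """Dense MLP: w_out @ relu(w_in @ x), with w_in of shape (H, D) and w_out of shape (D, H)."""
    hidden = [relu(h) for h in matvec(w_in, x)]
    return matvec(w_out, hidden)


def split_into_experts(w_in: Matrix, w_out: Matrix, n_experts: int) -> List[Expert]:
    """Partition the H hidden units into contiguous blocks, one per expert (assumed scheme).

    Expert e keeps rows [lo, hi) of w_in and columns [lo, hi) of w_out.
    """
    h = len(w_in)
    assert h % n_experts == 0, "hidden size must divide evenly into experts"
    size = h // n_experts
    experts = []
    for e in range(n_experts):
        lo, hi = e * size, (e + 1) * size
        experts.append((w_in[lo:hi], [row[lo:hi] for row in w_out]))
    return experts


def expert_mlp(experts: List[Expert], x: Vector, weights: Vector, k: int) -> Vector:
    """Run only the top-k experts, scaling each expert's output by its routing weight."""
    top = sorted(range(len(experts)), key=lambda e: weights[e], reverse=True)[:k]
    out = [0.0] * len(experts[0][1])  # output dimension D
    for e in top:
        w_in_e, w_out_e = experts[e]
        y = dense_mlp(w_in_e, w_out_e, x)
        out = [o + weights[e] * yi for o, yi in zip(out, y)]
    return out
```

Calling `expert_mlp` with every expert selected and all routing weights equal to 1 recovers `dense_mlp`; lowering `k` trades fidelity to the pretrained layer for reduced compute, which is the trade-off the routing modules are trained to manage.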