
Tuning Language Models by Mixture-of-Depths Ensemble

Haoyan Luo, Lucia Specia

TL;DR

This work introduces a novel tuning framework, Mixture-of-Depths (MoD), which trains late layers as ensembles contributing to the final logits through learned routing weights, demonstrating the potential of leveraging predictive power from intermediate representations during training.

Abstract

Transformer-based Large Language Models (LLMs) traditionally rely on final-layer loss for training and final-layer representations for predictions, potentially overlooking the predictive power embedded in intermediate layers. Surprisingly, we find that focusing training efforts on these intermediate layers can yield training losses comparable to those of final layers, with complementary test-time performance. We introduce a novel tuning framework, Mixture-of-Depths (MoD), which trains late layers as ensembles contributing to the final logits through learned routing weights. With an auxiliary distillation loss and additional normalization modules, we ensure that the outputs of the late layers are adapted to language modeling. Our MoD framework, which can be integrated with any existing tuning method, shows consistent improvement on various language modeling tasks. Furthermore, by replacing traditional trainable modules with MoD, our approach achieves similar performance with significantly fewer trainable parameters, demonstrating the potential of leveraging predictive power from intermediate representations during training.
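The abstract describes the MoD computation only at a high level. The NumPy sketch below shows one way to read it, under assumptions that do not come from the paper: an RMSNorm-style trainable scale per late layer (standing in for the normalization modules), one shared pre-trained LM head, a hypothetical linear router applied to the final-layer hidden state that produces softmax routing weights over the last k layers, and a KL-based distillation loss in which the final layer's distribution is the teacher, weighted by a hypothetical coefficient `lam`. The names `mod_forward`, `mod_loss`, `router_w`, and `lam` are illustrative only. This is a forward-pass and loss illustration, not the authors' implementation, and it computes no gradients.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def log_softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def rms_norm(h, scale, eps=1e-6):
    # Assumed RMSNorm-style normalization with a trainable per-dimension scale.
    return h / np.sqrt((h ** 2).mean(axis=-1, keepdims=True) + eps) * scale

def mod_forward(hidden_states, norm_scales, lm_head, router_w):
    """hidden_states: k arrays of shape (T, d) for layers L_{n-k+1}..L_n (last = final layer).
    norm_scales: k arrays of shape (d,), one normalization per late layer.
    lm_head: (d, V) shared pre-trained head. router_w: (d, k) hypothetical router."""
    # Per-layer logits: normalize each late layer, then apply the shared head.
    layer_logits = np.stack(
        [rms_norm(h, s) @ lm_head for h, s in zip(hidden_states, norm_scales)]
    )  # (k, T, V)
    # Per-token routing weights over the k late layers (each row sums to 1).
    weights = softmax(hidden_states[-1] @ router_w, axis=-1)  # (T, k)
    # Ensemble logits: routing-weighted sum of the per-layer logits.
    ensemble = np.einsum("tk,ktv->tv", weights, layer_logits)  # (T, V)
    return ensemble, layer_logits, weights

def mod_loss(ensemble, layer_logits, targets, lam=1.0):
    T = targets.shape[0]
    # Language-modeling cross-entropy on the ensemble logits.
    ce = -log_softmax(ensemble)[np.arange(T), targets].mean()
    # Distillation: final layer is the teacher (treated as a constant target).
    teacher_logp = log_softmax(layer_logits[-1])
    teacher_p = np.exp(teacher_logp)
    distill = 0.0
    for student in layer_logits[:-1]:
        # KL(teacher || student), averaged over tokens.
        distill += (teacher_p * (teacher_logp - log_softmax(student))).sum(-1).mean()
    distill /= max(len(layer_logits) - 1, 1)
    return ce + lam * distill, ce, distill
```

In an actual tuning run the same computation would be written in an autodiff framework so that the normalization scales and router receive gradients while the LM head stays frozen. Treating the teacher distribution as a constant follows the description of the final layer as the teacher.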

Paper Structure

This paper contains 23 sections, 6 equations, 8 figures, 5 tables.

Figures (8)

  • Figure 1: Tuning loss curves for LLaMA2-7B [llama2] on the ARC dataset [arc]. Top: loss curves of the late layers when only the loss on the last layer's output is optimized, so the late layers are optimized implicitly. Bottom: loss curves when the loss on each late layer's output is optimized together with the distillation loss $\mathcal{L}_{distill}$ with respect to the last layer.
  • Figure 2: Overlap of the problems solved when tuning with the loss on different layers, on the AQuA [aqua], ARC-Challenge [arc], and GSM8K [gsm8k] datasets. The numbers in the Venn diagrams give the count of problems solved by more than one method and the count solved by only one method.
  • Figure 3: The overall framework of Mixture-of-Depths (MoD), which can be applied on top of any tuning method, such as LoRA [hu2022lora]. Given a pre-trained LLM and a tuning dataset, MoD applies trainable normalization $\mathcal{N}_k$ and pre-trained language model heads $\phi(\cdot)$ to the last $k$ layers $\{L_{n-k+1}, \ldots, L_n\}$. The layers' outputs are combined using learned routing weights to produce the final logits. During training, an auxiliary teacher-enforced distillation loss $\mathcal{L}_{distill}$ is applied, with the final layer's output serving as the teacher. At inference, MoD uses the ensemble logits.
  • Figure 4: Sparsity scores for MoD (left) and MoD trained with $k$ LoRA layers (right). Curves are smoothed with a moving average.
  • Figure 5: Accuracy scores for different $k$ ensemble layer ranges and Top-K sparse routing values. Lighter colors indicate better performance.
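Figure 5 sweeps a Top-K sparse routing value alongside the ensemble range $k$. The captions do not say how sparse routing is implemented. The sketch below assumes a common mixture-of-experts convention: for each token, only the Top-K largest of the $k$ routing weights are kept and then renormalized to sum to 1. The function name `top_k_route` is hypothetical, and this convention may differ from the paper's.

```python
import numpy as np

def top_k_route(weights: np.ndarray, top_k: int) -> np.ndarray:
    """Keep the top_k largest routing weights per token and renormalize them.

    weights: (T, k) array of nonnegative routing weights whose rows sum to 1.
    Returns an array of the same shape with k - top_k zeros in each row.
    """
    k = weights.shape[-1]
    if not 1 <= top_k <= k:
        raise ValueError("top_k must lie in [1, k]")
    sparse = weights.astype(float)  # astype copies, so the input is not modified
    if top_k < k:
        # Column indices of the (k - top_k) smallest weights in each row.
        drop = np.argsort(sparse, axis=-1)[:, : k - top_k]
        np.put_along_axis(sparse, drop, 0.0, axis=-1)
    # Renormalize the surviving weights so each row sums to 1 again.
    return sparse / sparse.sum(axis=-1, keepdims=True)
```

Under this reading, the sparse weights simply replace the dense ones in the ensemble sum over per-layer logits, so only Top-K late layers contribute to each token's prediction.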