Joint Fine-tuning and Conversion of Pretrained Speech and Language Models towards Linear Complexity

Mutian He, Philip N. Garner

TL;DR

This paper addresses the cost of deploying large pretrained transformers, particularly in non-text domains where pretrained linear-time models are often unavailable. It proposes Cross-Architecture Layerwise Distillation (CALD), which jointly converts a transformer into a linear-time substitute and fine-tunes it on a target task. CALD transfers parameters and distills behavior from a task-tuned teacher to a linear-time student (e.g., RoBERTa→Linformer, Pythia→Mamba, Wav2Vec2→Mamba2), using a combined loss over cross-entropy, output distillation, and hidden-state alignment. It explores four guidance modes: Target Guided, Trajectory Guided, Waypoint Guided, and Hybrid. Across language processing, language modeling, and speech tasks, CALD substantially narrows the performance gap relative to unguided parameter transfer and approaches or surpasses the baselines; trajectory and hybrid strategies provide further gains, depending on how far the hidden states drift during fine-tuning. A practical benefit is that full pretraining of the linear-time model is avoided (e.g., converting Wav2Vec2-large to Mamba2 takes about 1.6 days on a single GPU), which makes linear-time architectures more accessible for long-form speech and cross-domain tasks.
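The combined objective named in the TL;DR can be illustrated with a minimal sketch. The summary does not give the exact formulation, so everything specific here is an assumption: KL divergence for output distillation, mean squared error for hidden-state alignment, the weights `alpha` and `beta`, the `temperature` parameter, and all function names are hypothetical and chosen only for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def cross_entropy(student_logits, label):
    """Negative log-likelihood of the gold label under the student."""
    return -log_softmax(student_logits)[label]

def kl_divergence(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) on temperature-softened distributions (assumed form)."""
    t_logp = log_softmax([v / temperature for v in teacher_logits])
    s_logp = log_softmax([v / temperature for v in student_logits])
    return sum(math.exp(t) * (t - s) for t, s in zip(t_logp, s_logp))

def hidden_mse(teacher_layers, student_layers):
    """MSE between matching layers' hidden vectors, averaged over layers (assumed form)."""
    per_layer = []
    for t_vec, s_vec in zip(teacher_layers, student_layers):
        sq_err = sum((t - s) ** 2 for t, s in zip(t_vec, s_vec))
        per_layer.append(sq_err / len(t_vec))
    return sum(per_layer) / len(per_layer)

def combined_loss(student_logits, teacher_logits, label,
                  student_layers, teacher_layers, alpha=1.0, beta=1.0):
    """Task loss + alpha * output distillation + beta * layerwise alignment."""
    ce = cross_entropy(student_logits, label)
    kd = kl_divergence(teacher_logits, student_logits)
    align = hidden_mse(teacher_layers, student_layers)
    return ce + alpha * kd + beta * align
```

When the student reproduces the teacher's logits and hidden states exactly, both distillation terms vanish and only the cross-entropy term remains, which is a quick sanity check for any implementation of this kind of objective.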

Abstract

Architectures such as Linformer and Mamba have recently emerged as competitive linear time replacements for transformers. However, corresponding large pretrained models are often unavailable, especially in non-text domains. To remedy this, we present a Cross-Architecture Layerwise Distillation (CALD) approach that jointly converts a transformer model to a linear time substitute and fine-tunes it to a target task. We also compare several means to guide the fine-tuning to optimally retain the desired inference capability from the original model. The methods differ in their use of the target model and the trajectory of the parameters. In a series of empirical studies on language processing, language modeling, and speech processing, we show that CALD can effectively recover the result of the original model, and that the guiding strategy contributes to the result. Some reasons for the variation are suggested.

Paper Structure

This paper contains 15 sections, 2 equations, 5 figures, 4 tables, 3 algorithms.

Figures (5)

  • Figure 1: A conceptual illustration of how hidden states shift during training under different modes of distillation. Given the trajectory (green line) of the hidden states during fine-tuning from the source teacher model (i.e., the pretrained transformer) to the target teacher model (i.e., the transformer fine-tuned on the target task), we consider: a) Unguided: parameter transfer without any distillation; b) Target Guided: distill from the target teacher; c) Trajectory/Waypoint Guided: gradually distill from a series of models on the trajectory; d) Hybrid: distill from the target teacher until a certain step.
  • Figure 2: Encoder layers in the student bidirectional Mamba2 model (b), converted from Wav2Vec2 encoder layers in the teacher model (a). In each encoder layer, the attention layer is replaced by two new forward and backward Mamba2 mixers. Inputs and outputs of the backward mixer are timewise inverted. Hidden states after the feed-forward layer are extracted for layerwise distillation.
  • Figure 3: Average cosine distance between the hidden states produced by the initial model and those produced by the fine-tuning checkpoints, measured after each epoch (for NLP tasks) or every 10,000 steps (for speech tasks). Unlike for the NLP models, the features produced by the fine-tuned speech models move far from the initial ones early in fine-tuning.
  • Figure 4: The trajectory of hidden states during training under different modes of guidance, visualized using t-SNE.
  • Figure 5: Time costs for Mamba2 and transformer-based models performing ASR on audio samples of different lengths, averaged over 5 runs on a single RTX 3090.
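
Two of the captions above describe computations that are concrete enough to sketch: the average cosine distance of Figure 3 and the time-reversed backward mixer of Figure 2. The code below is a toy illustration under stated assumptions, not the paper's implementation. The function names are hypothetical, `causal_running_mean` is only a stand-in for a real Mamba2 mixer, each time step is a single scalar for brevity, and summing the forward and backward outputs is an assumption, since the caption does not say how the two are combined.

```python
import math

def cosine_distance(u, v):
    """1 - cosine similarity; assumes both vectors are non-zero."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

def average_drift(initial_states, checkpoint_states):
    """Mean cosine distance between paired hidden-state vectors (Figure 3 style)."""
    distances = [cosine_distance(u, v)
                 for u, v in zip(initial_states, checkpoint_states)]
    return sum(distances) / len(distances)

def causal_running_mean(sequence):
    """Stand-in causal mixer: output at step t depends only on inputs 0..t."""
    out, total = [], 0.0
    for count, x in enumerate(sequence, start=1):
        total += x
        out.append(total / count)
    return out

def bidirectional_mix(sequence, forward_mixer, backward_mixer):
    """Run one mixer left-to-right and one on the time-reversed input (Figure 2 style)."""
    fwd = forward_mixer(sequence)
    # Reverse the input, mix causally, then reverse the output back into
    # the original time order so both streams are aligned per time step.
    bwd = backward_mixer(sequence[::-1])[::-1]
    # Combining by summation is an assumption made for this sketch.
    return [f + b for f, b in zip(fwd, bwd)]
```

For example, `bidirectional_mix([1.0, 3.0], causal_running_mean, causal_running_mean)` gives each position access to context from both directions, even though each underlying mixer is strictly causal.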