Faster Speech-LLaMA Inference with Multi-token Prediction

Desh Raj, Gil Keren, Junteng Jia, Jay Mahadeokar, Ozlem Kalinli

TL;DR

This work speeds up Speech-LLaMA inference by predicting multiple tokens in each decoding step, and proposes a prefix-based beam search decoding method that enables efficient minimum word error rate (MWER) training for such models.

Abstract

Large language models (LLMs) have become proficient at solving a wide variety of tasks, including those involving multi-modal inputs. In particular, instantiating an LLM (such as LLaMA) with a speech encoder and training it on paired data imparts automatic speech recognition (ASR) abilities to the resulting decoder-only model, referred to as Speech-LLaMA. Nevertheless, due to the sequential nature of auto-regressive inference and the relatively large decoder, Speech-LLaMA models have relatively long inference times. In this work, we propose to speed up Speech-LLaMA inference by predicting multiple tokens in the same decoding step. We explore several model architectures that enable this and investigate their performance using threshold-based and verification-based inference strategies. We also propose a prefix-based beam search decoding method that allows efficient minimum word error rate (MWER) training for such models. We evaluate our models on a variety of public benchmarks, where they reduce the number of decoder calls by ~3.2x while maintaining or improving word error rate (WER).
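
The abstract contrasts threshold-based and verification-based acceptance of the extra predicted tokens. The following is a minimal Python/NumPy sketch of both acceptance rules, assuming a decoder call that returns one distribution per head, where head k predicts the token k positions ahead. `toy_decoder_call`, `HEAD_CONF`, `TARGET`, and the token ids are hypothetical stand-ins, not the paper's model, and the paper's exact rules may differ. `top_m_accept` implements only the acceptance test; in a verification-based decoder the verifier distributions would come from the next-token head evaluated over the drafted positions in one causal forward pass, in the style of blockwise parallel decoding.

```python
import numpy as np

# Hypothetical toy setup: token 0 = BOS, token 1 = EOS, vocabulary of 10 ids.
BOS, EOS, VOCAB = 0, 1, 10
TARGET = [BOS, 5, 3, 8, 2, 7, EOS]  # the "transcript" the toy model knows
HEAD_CONF = [0.95, 0.75, 0.55]      # assumption: later heads are less confident


def toy_decoder_call(prefix):
    """Stand-in for one decoder call with K = 3 output heads.

    Row k of the returned (K, VOCAB) array is the distribution for the token
    at position len(prefix) + k.
    """
    probs = np.empty((len(HEAD_CONF), VOCAB))
    for k, conf in enumerate(HEAD_CONF):
        pos = len(prefix) + k
        tok = TARGET[pos] if pos < len(TARGET) else EOS
        probs[k, :] = (1.0 - conf) / (VOCAB - 1)
        probs[k, tok] = conf
    return probs


def threshold_decode(decoder_call, theta, max_len=100):
    """Greedy multi-token decoding with threshold-based acceptance."""
    tokens, calls = [BOS], 0
    while len(tokens) < max_len and tokens[-1] != EOS:
        probs = decoder_call(tokens)
        calls += 1
        # Head 0 is the usual next-token head, so its argmax is always kept.
        accepted = [int(np.argmax(probs[0]))]
        # Extra heads: keep consecutive argmax tokens while max prob >= theta.
        for k in range(1, probs.shape[0]):
            if accepted[-1] == EOS or probs[k].max() < theta:
                break
            accepted.append(int(np.argmax(probs[k])))
        tokens.extend(accepted)
    return tokens, calls


def top_m_accept(draft, verifier_probs, m):
    """Number of leading draft tokens that lie in the verifier's top-m."""
    accepted = 0
    for tok, p in zip(draft, verifier_probs):
        top_m = np.argsort(p)[::-1][:m].tolist()  # m most likely token ids
        if tok not in top_m:
            break
        accepted += 1
    return accepted


if __name__ == "__main__":
    for theta in (1.0, 0.7, 0.5):
        tokens, calls = threshold_decode(toy_decoder_call, theta)
        print(f"theta={theta}: {calls} decoder calls, correct={tokens == TARGET}")
```

On this toy, lowering θ from 1.0 to 0.7 to 0.5 reduces the decoder calls for the same six-token output from 6 to 3 to 2. This mirrors the trade-off described for Figure 4: weaker acceptance conditions (lower θ, or higher M for top-M) mean fewer decoder calls, at the risk of accepting wrong tokens when an extra head is confident but mistaken.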

Paper Structure

This paper contains 11 sections, 12 equations, 5 figures, and 3 tables.

Figures (5)

  • Figure 1: Overview of decoder-only ASR, e.g., Speech-LLaMA.
  • Figure 2: Architectures for multi-token prediction: (a) independent projection heads, and (b) latent-space expansion. The green-shaded blocks denote parameters initialized from the pre-trained LLM.
  • Figure 3: Dataset statistics for large-scale multilingual training. Colors denote the source corpus: Common Voice (CV), VoxPopuli (VP), and Multilingual LibriSpeech (MLS). Each VP subset is sampled in 2% of the batches.
  • Figure 4: Plot of WER (%) versus $\eta$ on LibriSpeech dev-other for different numbers of latent heads and inference methods. Each color represents a different model, and marker styles denote different decoding methods. Lighter markers denote weaker acceptance conditions, i.e., higher $M$ for top-$M$ and lower $\theta$ for threshold-based decoding. The single-head model has $\eta=1.26$.
  • Figure 5: Comparison of the number of decoder calls ($\eta$) across languages, averaged over CV, VP, and MLS.
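
Figure 2 names two ways of producing several token predictions from one decoder call. The NumPy sketch below encodes one plausible reading of that caption, not the paper's exact layers: (a) gives every offset its own hidden-to-vocabulary projection, while (b) expands the final hidden state into K latent vectors that all pass through the single pre-trained output projection. The sizes `D`, `V`, `K`, the initializations (copying the pre-trained projection in (a), random `W_expand` in (b)), and all function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
D, V, K = 16, 50, 3  # hypothetical hidden size, vocabulary size, number of heads

# Stand-in for the pre-trained LLM output projection (hidden -> vocab).
W_lm = rng.normal(scale=0.1, size=(D, V))


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


# (a) Independent projection heads: one hidden->vocab matrix per offset k.
#     Assumption: each head starts as a copy of the pre-trained projection.
W_heads = np.stack([W_lm.copy() for _ in range(K)])  # (K, D, V)


def independent_heads(h):
    """h: (D,) final decoder state -> (K, V) next-token distributions."""
    return softmax(np.einsum("d,kdv->kv", h, W_heads))


# (b) Latent-space expansion: map h to K latent vectors (new parameters),
#     then reuse the single shared pre-trained projection for every latent.
W_expand = rng.normal(scale=0.1, size=(K, D, D))  # (K, D, D)


def latent_expansion(h):
    """h: (D,) final decoder state -> (K, V) next-token distributions."""
    latents = np.einsum("d,kde->ke", h, W_expand)  # (K, D)
    return softmax(latents @ W_lm)                 # (K, V)


if __name__ == "__main__":
    h = rng.normal(size=D)
    print(independent_heads(h).shape, latent_expansion(h).shape)
    print("new output params (a):", W_heads.size, " (b):", W_expand.size)
```

The main design difference in this sketch is the parameter count: (a) holds K·D·V output weights, while (b) adds only K·D·D and reuses the pre-trained projection, which is far cheaper whenever the vocabulary size V is much larger than the hidden size D. Because all heads in (a) start from the same copy here, they produce identical distributions until fine-tuning separates them.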