MrT5: Dynamic Token Merging for Efficient Byte-level Language Models

Julie Kallini, Shikhar Murty, Christopher D. Manning, Christopher Potts, Róbert Csordás

TL;DR

MrT5 is a more efficient variant of ByT5 that integrates a token deletion mechanism in its encoder to dynamically shorten the input sequence length, addressing the practical limitations of existing byte-level models.

Abstract

Models that rely on subword tokenization have significant drawbacks, such as sensitivity to character-level noise like spelling errors and inconsistent compression rates across different languages and scripts. While character- or byte-level models like ByT5 attempt to address these concerns, they have not gained widespread adoption -- processing raw byte streams without tokenization results in significantly longer sequence lengths, making training and inference inefficient. This work introduces MrT5 (MergeT5), a more efficient variant of ByT5 that integrates a token deletion mechanism in its encoder to dynamically shorten the input sequence length. After processing through a fixed number of encoder layers, a learned delete gate determines which tokens are to be removed and which are to be retained for subsequent layers. MrT5 effectively "merges" critical information from deleted tokens into a more compact sequence, leveraging contextual information from the remaining tokens. In continued pre-training experiments, we find that MrT5 can achieve significant gains in inference runtime with minimal effect on performance, as measured by bits-per-byte. Additionally, with multilingual training, MrT5 adapts to the orthographic characteristics of each language, learning language-specific compression rates. Furthermore, MrT5 shows comparable accuracy to ByT5 on downstream evaluations such as XNLI, TyDi QA, and character-level tasks while reducing sequence lengths by up to 75%. Our approach presents a solution to the practical limitations of existing byte-level models.
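
The gate mechanism described above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the authors' implementation: the function names, the sigmoid-shaped gate, the scale k = -30, and the k/2 threshold are hypothetical choices made for this example, and the per-token scores are hand-picked rather than learned. The sketch contrasts a soft variant, which keeps the sequence length and suppresses attention to gated tokens, with a hard variant, which drops them.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def gate_values(scores, k=-30.0):
    # Hypothetical gate: squash each raw score into (k, 0).
    # Values near 0 mean "keep"; values near k mean "delete".
    return [k / (1.0 + math.exp(-s)) for s in scores]

def soft_deletion_attention(logits, gate):
    # Soft variant: add the gate value to each key's attention logit.
    # The sequence length is unchanged, but gated tokens get ~0 weight.
    return softmax([l + g for l, g in zip(logits, gate)])

def hard_deletion(tokens, gate, k=-30.0):
    # Hard variant: drop tokens whose gate value is below the threshold,
    # so later layers see a shorter sequence.
    return [t for t, g in zip(tokens, gate) if g >= k / 2]

tokens = list(b"hello")                 # byte-level input
scores = [-8.0, 8.0, 8.0, -8.0, 8.0]    # hand-picked: positive means "delete"
gate = gate_values(scores)

weights = soft_deletion_attention([0.0] * len(tokens), gate)
kept = hard_deletion(tokens, gate)

print(len(weights), len(kept))
```

In this toy run the soft variant still returns one attention weight per input byte, while the hard variant shortens the sequence from five bytes to two. Only the hard variant reduces the work done by later layers.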

Paper Structure

This paper contains 55 sections, 16 equations, 7 figures, and 11 tables.

Figures (7)

  • Figure 1: MrT5's encoder during training and testing. During training, fully-differentiable soft deletion masks out tokens using the output of MrT5's delete gate. During testing, hard deletion removes columns from the computation, which reduces the sequence length and leads to efficiency gains. In this visual, the delete gate is placed at layer 2, but the gate placement may be tuned.
  • Figure 2: Span corruption BPB vs. sequence length reduction for each MrT5 and baseline model. MrT5 models consistently have much lower BPB than the baselines, and are generally competitive with unmodified ByT5, even where they achieve very large sequence length reductions.
  • Figure 3: Average test set BPB vs. sequence length reduction for MrT5 ($\delta=0.5$) across each of the 15 languages. ByT5 is shown for BPB comparison only (it does not reduce the sequence length). MrT5 learns language-specific sequence length reduction rates and achieves over 50% reduction in many languages with minimal effect on the BPB.
  • Figure 4: BPB and inference runtime for a single sequence for MrT5 models with delete gates at different layers ($l \in [1, 5]$). All MrT5 models are trained with a PI controller with a target deletion ratio of $\delta = 0.5$. BPB is consistently higher when the gate is placed in early layers; since the gate should otherwise be placed as early as possible, layer 3 is selected as the best trade-off.
  • Figure 5: Reduction in the total amount of compute as a function of the deletion ratio $\delta$.
  • ...and 2 more figures
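
The Figure 4 caption mentions training with a PI controller toward a target deletion ratio. The sketch below shows the general proportional-integral technique only, and everything specific in it is an assumption: that the controller adjusts a scalar weight `alpha` on some deletion regularizer, the gains `kp` and `ki`, and the `toy_deletion_ratio` function, which is a hypothetical stand-in for how a model's observed deletion ratio might respond to that weight.

```python
class PIController:
    """Generic proportional-integral controller (illustrative only)."""

    def __init__(self, target, kp, ki):
        self.target = target    # desired deletion ratio (delta)
        self.kp = kp            # proportional gain (assumed value)
        self.ki = ki            # integral gain (assumed value)
        self.integral = 0.0     # running sum of past errors

    def update(self, observed):
        # A positive error means too few tokens are being deleted.
        error = self.target - observed
        self.integral += error
        # Clamp at zero so the regularizer weight never goes negative.
        return max(0.0, self.kp * error + self.ki * self.integral)

def toy_deletion_ratio(alpha):
    # Hypothetical response: more regularizer weight leads to more
    # deletion, saturating below 1. A real model would replace this.
    return alpha / (1.0 + alpha)

controller = PIController(target=0.5, kp=0.5, ki=0.5)
ratio = 0.0
for _ in range(200):
    alpha = controller.update(ratio)
    ratio = toy_deletion_ratio(alpha)

final_ratio = ratio
print(round(final_ratio, 3))
```

The integral term is what lets the controller settle on the target instead of stalling at a persistent offset: as long as the observed ratio stays below the target, the accumulated error keeps raising `alpha`.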