Token-Weighted RNN-T for Learning from Flawed Data

Gil Keren, Wei Zhou, Ozlem Kalinli

TL;DR

This work addresses a limitation of standard RNN-T training: all target tokens are treated equally, making the objective vulnerable to transcription errors. It introduces a token-weighted RNN-T objective $L_w = -\sum_{u=1}^{U} \lambda_u \log P(y_u | y_{<u})$ and derives an exact, efficient computation of $P(y_u | y_{<u})$ in terms of aggregated alignment paths, enabling token-level weighting. Token weights $\lambda_u$ are derived from teacher confidence scores $c_u = P_{teacher}(y_u | y_{<u})$ via $\lambda_u = c_u^{\alpha} / (\frac{1}{U'} \sum_{u'} c_{u'}^{\alpha})$, normalized so that the weights average to one over the $U'$ tokens in the batch; this addresses both pseudo-labeling noise and human annotation errors. Empirically, the approach yields up to 38% relative WER improvement in semi-supervised learning on LibriSpeech and mitigates the degradation caused by corrupted references, recovering 64–99% of the accuracy loss, which illustrates its practical value for real-world ASR systems trained on noisy data.
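As a concrete illustration, here is a minimal sketch of the weight computation implied by the formulas above, assuming the per-token teacher confidences have already been gathered. The function name, tensor shapes, padding mask, and clamp epsilon are illustrative choices, not details from the paper.

```python
import torch

def token_weights(confidence: torch.Tensor,
                  mask: torch.Tensor,
                  alpha: float = 1.0) -> torch.Tensor:
    """Token weights lambda_u = c_u^alpha / mean(c^alpha over the batch).

    confidence: (B, U) teacher probabilities c_u = P_teacher(y_u | y_<u).
    mask:       (B, U) bool, True for real tokens, False for padding.
    Returns (B, U) weights whose mean over real tokens is 1.
    """
    # Raise confidences to the sharpening exponent alpha; the small clamp
    # guards against underflowing teacher scores (a safety measure added
    # here, not taken from the paper).
    powered = confidence.clamp_min(1e-8).pow(alpha)
    # Normalize by the mean of c^alpha over all real tokens in the batch,
    # so weighting rescales tokens relative to one another without changing
    # the overall loss magnitude.
    mean_powered = (powered * mask).sum() / mask.sum()
    return (powered / mean_powered) * mask
```

Setting $\alpha = 0$ makes every weight 1 and recovers the standard unweighted RNN-T objective, while larger $\alpha$ suppresses low-confidence tokens more aggressively; dividing by the batch mean keeps the overall loss scale comparable to the unweighted case.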

Abstract

ASR models are commonly trained with the cross-entropy criterion to increase the probability of a target token sequence. While optimizing the probability of all tokens in the target sequence is sensible, one may want to de-emphasize tokens that reflect transcription errors. In this work, we propose a novel token-weighted RNN-T criterion that augments the RNN-T objective with token-specific weights. The new objective is used for mitigating accuracy loss from transcription errors in the training data, which naturally appear in two settings: pseudo-labeling and human annotation errors. Experimental results show that using our method for semi-supervised learning with pseudo-labels leads to a consistent accuracy improvement, up to 38% relative. We also analyze the accuracy degradation resulting from different levels of WER in the reference transcription, and show that token-weighted RNN-T is suitable for overcoming this degradation, recovering 64–99% of the accuracy loss.

Paper Structure

This paper contains 11 sections, 8 equations, 1 figure, and 3 tables.

Figures (1)

  • Figure 1: All possible partial alignments from the token 'C' to the token 'A' in four audio frames. Each image corresponds to a different time step at which the token 'C' may have been emitted. The horizontal arrows correspond to emitting the blank symbol $\phi$. When computing $P(y_u | y_{<u})$ we sum across all of the above paths.
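The summation the caption describes can be made concrete with the standard RNN-T forward variable: because the model is locally normalized, the probability of a prefix $y_{1:u}$ is obtained by marginalizing over the frame at which $y_u$ is emitted (all continuations sum to probability 1), and $P(y_u | y_{<u})$ then follows from the chain rule. Below is a minimal, unoptimized sketch of this idea; the function name, tensor layout, and explicit Python loops are assumptions for illustration, not the paper's efficient implementation.

```python
import torch

def token_log_probs(log_emit: torch.Tensor, log_blank: torch.Tensor) -> torch.Tensor:
    """log P(y_u | y_<u) for u = 1..U, summed over all alignment paths.

    log_emit[t, u]  : log-prob of emitting token y_{u+1} at lattice node (t, u).
    log_blank[t, u] : log-prob of emitting blank at lattice node (t, u).
    Shapes: log_emit is (T, U), log_blank is (T, U + 1). Returns a (U,) tensor.
    """
    T, U = log_emit.shape
    neg_inf = log_emit.new_tensor(float('-inf'))
    # alpha[t, u]: log-probability of all partial paths that have emitted
    # y_{1:u} and are currently positioned at frame t (the standard RNN-T
    # forward recursion over the frame-by-token lattice).
    alpha = log_emit.new_full((T, U + 1), float('-inf'))
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Horizontal arrow: blank emitted at frame t-1, token count unchanged.
            stay = alpha[t - 1, u] + log_blank[t - 1, u] if t > 0 else neg_inf
            # Vertical arrow: token y_u emitted at frame t.
            emit = alpha[t, u - 1] + log_emit[t, u - 1] if u > 0 else neg_inf
            alpha[t, u] = torch.logaddexp(stay, emit)
    # log P(y_{1:u}): sum over the frame at which y_u is emitted; because the
    # model is locally normalized, every continuation sums to probability 1.
    prefix = torch.logsumexp(alpha[:, :U] + log_emit, dim=0)  # shape (U,)
    # Chain rule: log P(y_u | y_<u) = log P(y_{1:u}) - log P(y_{1:u-1}).
    return prefix - torch.cat([prefix.new_zeros(1), prefix[:-1]])
```

Each panel in the figure corresponds to one choice of the emission frame, and these events are disjoint, so their probabilities add. In practice the recursion would be vectorized (e.g., one lattice row or anti-diagonal at a time), and the resulting per-token log-probabilities plug directly into the weighted objective $L_w$ above.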