Token-Weighted RNN-T for Learning from Flawed Data
Gil Keren, Wei Zhou, Ozlem Kalinli
TL;DR
This work addresses a limitation of standard RNN-T training: it treats all target tokens equally and is therefore vulnerable to transcription errors. It introduces a token-weighted RNN-T objective $L_w = -\sum_{u=1}^{U} \lambda_u \log P(y_u \mid y_{<u}, x)$ and derives an exact, efficient computation of $P(y_u \mid y_{<u}, x)$ in terms of aggregated alignment paths, enabling token-level weighting. Token weights $\lambda_u$ are derived from teacher confidence scores via $c_u = P_{teacher}(y_u \mid y_{<u}, x)$ and $\lambda_u = c_u^{\alpha} / (\frac{1}{U'} \sum_{u'} c_{u'}^{\alpha})$, normalized so that the mean weight is one, addressing both pseudo-labeling noise and human annotation errors. Empirically, the approach yields up to 38% relative WER improvement in semi-supervised learning on Librispeech and mitigates degradation from corrupted references, recovering 64–99% of the accuracy loss, demonstrating practical impact for real-world ASR systems trained on noisy data.
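The weight computation above can be sketched as follows. This is an illustrative implementation of the confidence-to-weight mapping only (the function name and use of NumPy are assumptions; in the paper, the confidences come from a teacher model and the weights enter the RNN-T loss per token):

```python
import numpy as np

def token_weights(confidences, alpha=1.0):
    """Map per-token teacher confidences c_u to weights
    lambda_u = c_u**alpha / mean(c**alpha), so the weights average to 1.
    alpha controls how aggressively low-confidence tokens are down-weighted."""
    c = np.asarray(confidences, dtype=np.float64)
    powered = c ** alpha
    return powered / powered.mean()

# A token with low teacher confidence (e.g. a likely transcription error)
# receives a weight well below 1, while confident tokens are up-weighted.
weights = token_weights([0.9, 0.99, 0.2], alpha=2.0)
```

With `alpha = 0` all weights collapse to 1 and the objective reduces to standard RNN-T training; larger `alpha` sharpens the contrast between confident and unconfident tokens.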
Abstract
ASR models are commonly trained with the cross-entropy criterion to increase the probability of a target token sequence. While optimizing the probability of all tokens in the target sequence is sensible, one may want to de-emphasize tokens that reflect transcription errors. In this work, we propose a novel token-weighted RNN-T criterion that augments the RNN-T objective with token-specific weights. The new objective is used to mitigate the accuracy loss from transcription errors in the training data, which naturally arise in two settings: pseudo-labeling and human annotation errors. Experimental results show that using our method for semi-supervised learning with pseudo-labels leads to a consistent accuracy improvement, up to 38% relative. We also analyze the accuracy degradation resulting from different levels of WER in the reference transcriptions, and show that token-weighted RNN-T overcomes this degradation, recovering 64–99% of the accuracy loss.
