
Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding

Ukyo Honda, Tatsushi Oka, Peinan Zhang, Masato Mita

TL;DR

This work pessimistically aggregates the predictions of a mixture-of-experts, assuming each expert captures relatively different latent features, to address distribution shifts in shortcuts.

Abstract

Recent models for natural language understanding tend to exploit simple patterns in datasets, commonly known as shortcuts. These shortcuts rely on spurious correlations between labels and latent features that exist in the training data. At inference time, shortcut-dependent models are likely to make erroneous predictions under distribution shifts, particularly when some latent features are no longer correlated with the labels. To avoid this, previous studies have trained models to eliminate their reliance on shortcuts. In this study, we explore a different direction: pessimistically aggregating the predictions of a mixture-of-experts, assuming each expert captures relatively different latent features. The experimental results demonstrate that our post-hoc control over the experts significantly enhances the model's robustness to the distribution shift in shortcuts. In addition, we show that our approach has some practical advantages. Finally, we analyze our model and provide results that support this assumption.
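
A minimal formalization of the pessimistic aggregation, under assumptions that are ours rather than stated in this summary: expert $E^k$ outputs a label distribution $p_k(y \mid x)$ (hypothetical notation), the router $\pi$ produces weights on the probability simplex $\Delta^{K-1}$, and "pessimistic" means scoring each label by its worst case over all such weightings. Because the score is linear in the weights, the worst case is attained by putting all of the weight on a single expert:

```latex
\begin{align*}
p(y \mid x) &= \sum_{k=1}^{K} \pi_k(x)\, p_k(y \mid x), \qquad \pi(x) \in \Delta^{K-1}
  && \text{(mixture used as is for ID data)} \\
\sum_{k=1}^{K} w_k\, p_k(y \mid x) &\ge \Big(\min_{j} p_j(y \mid x)\Big) \sum_{k=1}^{K} w_k = \min_{j} p_j(y \mid x)
  && \text{for every } w \in \Delta^{K-1} \\
s(y \mid x) &= \min_{w \in \Delta^{K-1}} \sum_{k=1}^{K} w_k\, p_k(y \mid x) = \min_{k} p_k(y \mid x)
  && \text{(equality at the one-hot } w \text{ on the minimizing expert)} \\
\hat{y} &= \arg\max_{y \in \mathcal{Y}} \; \min_{k} p_k(y \mid x)
  && \text{(pessimistic decision rule)}
\end{align*}
```

Under this reading, the minimizing expert can differ from label to label, which fits the name "argmin weighting" used in Figure 4; the exact form of $\pi^*$ in the paper may differ from this sketch.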

Paper Structure (35 sections, 16 equations, 5 figures, 5 tables, 1 algorithm)

Figures (5)

  • Figure 1: An illustrative example of shortcuts in the task of natural language inference. P and H denote the premise and hypothesis sentences, respectively. $\{a_i\}$ are latent features related to $x$. The value to the right of $\hat{y}$ shows the confidence $\in [0.0, 1.0]$ of the prediction. $a_i$ is correctly predictive of the label $y$ in the training and in-distribution (ID) data but not in the out-of-distribution (OOD) data, where the association between $a_i$ and $y$ has changed. $a^*$ is an ideal latent feature that is predictive of $y$ across distributions. However, models generally find it difficult to rely on such an $a^*$. This figure illustrates the common case where predictions are not based on $a^*$.
  • Figure 2: The mixture weights averaged over each split of the datasets. Each split, excluding the ID dev set (Dev), has its own dominant feature. For FEVER, $\text{Dev}^{\textit{label}}_{\text{bigram}}$ consists of the instances in the ID dev set that contain the bigram reported to correlate strongly with the label (Schuster et al., 2019). Here, no post-hoc control is performed on the mixture weights.
  • Figure 3: Overview of our method. We fit the training data using a mixture model consisting of $K$ expert networks $\{E^k\}_{k=1}^K$ and a router network $\pi$ (see the section on the mixture model). During inference, the model is used as is for ID data, and $\pi$ is replaced with $\pi^*$ for OOD data (see the section on post-hoc control).
  • Figure 4: An example of decision-making with argmin weighting, where $K=2$ and $|\mathcal{Y}|=3$. After argmin weighting is applied, label 2 achieves the highest score (starred) and is therefore chosen as the answer.
  • Figure 5: The average prediction of each expert $E^k$ across the MNLI ID dev set.
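
The inference rule suggested by Figures 3 and 4, together with the per-split averages plotted in Figures 2 and 5, can be sketched in NumPy as below. This is a minimal sketch under stated assumptions, not the authors' implementation: `expert_probs` holds each expert's predicted label distribution, the router weights are used as is for ID data, and for OOD data we assume argmin weighting scores each label by the lowest probability any expert assigns to it. All function names and numbers are hypothetical; the numbers are invented so that label 2 wins, mirroring the scenario Figure 4 describes, and are not the figure's values.

```python
import numpy as np


def mixture_predict(expert_probs: np.ndarray, router_weights: np.ndarray) -> np.ndarray:
    """ID inference: weight each expert's label distribution by the router output.

    expert_probs: shape (K, |Y|); row k is expert E^k's distribution over labels.
    router_weights: shape (K,), non-negative and summing to 1 (the router pi).
    Returns the mixture distribution, shape (|Y|,).
    """
    return router_weights @ expert_probs


def argmin_weighting_scores(expert_probs: np.ndarray) -> np.ndarray:
    """OOD inference (assumed form): for each label, put all weight on the expert
    that gives that label the lowest probability, i.e. the worst case over all
    weightings on the simplex. Returns per-label scores, shape (|Y|,).
    """
    return expert_probs.min(axis=0)


def predict(expert_probs: np.ndarray, router_weights: np.ndarray, ood: bool) -> int:
    """Return the 0-based index of the label with the highest score."""
    if ood:
        scores = argmin_weighting_scores(expert_probs)
    else:
        scores = mixture_predict(expert_probs, router_weights)
    return int(np.argmax(scores))


def average_by_split(values: np.ndarray, splits: np.ndarray) -> dict:
    """Average per-instance arrays within each split, e.g. router weights of
    shape (N, K) as in Figure 2, or expert predictions of shape (N, K, |Y|)
    as in Figure 5. `splits` is a length-N array of split names.
    """
    return {s: values[splits == s].mean(axis=0) for s in np.unique(splits)}


if __name__ == "__main__":
    # Hypothetical numbers for K = 2 experts and |Y| = 3 labels (not from the paper).
    probs = np.array([[0.6, 0.3, 0.1],
                      [0.1, 0.4, 0.5]])
    router = np.array([0.8, 0.2])
    # Mixture is [0.50, 0.32, 0.18], so label 1 (1-based) wins for ID data.
    print(predict(probs, router, ood=False) + 1)
    # Per-label minima are [0.1, 0.3, 0.1], so label 2 (1-based) wins for OOD data.
    print(predict(probs, router, ood=True) + 1)
```

The design point the sketch isolates is that the OOD rule ignores the router entirely: the router that fits the training data is exactly what may encode the shortcut, so the pessimistic score depends only on the experts' predictions.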