In-context Learning in Presence of Spurious Correlations

Hrayr Harutyunyan, Rafayel Darbinyan, Samvel Karapetyan, Hrant Khachatrian

TL;DR

An in-context learner that generalizes to unseen tasks can be obtained by training on a diverse dataset of synthetic in-context learning instances.

Abstract

Large language models exhibit a remarkable capacity for in-context learning, where they learn to solve tasks given a few examples. Recent work has shown that transformers can be trained to perform simple regression tasks in-context. This work explores the possibility of training an in-context learner for classification tasks involving spurious features. We find that the conventional approach of training in-context learners is susceptible to spurious features. Moreover, when the meta-training dataset includes instances of only one task, the conventional approach leads to task memorization and fails to produce a model that leverages context for predictions. Based on these observations, we propose a novel technique to train such a learner for a given classification task. Remarkably, this in-context learner matches and sometimes outperforms strong methods like ERM and GroupDRO. However, unlike these algorithms, it does not generalize well to other tasks. We show that it is possible to obtain an in-context learner that generalizes to unseen tasks by training on a diverse dataset of synthetic in-context learning instances.

Paper Structure (24 sections, 2 equations, 14 figures, 4 tables)

Figures (14)

  • Figure 1: In-context learning transformer architectures of the naive and proposed approaches. The proposed approach allows arbitrary query tokens after each learning example. Token positions and attention mask are modified so that these intermediate queries have no effect on other tokens.
  • Figure 2: Majority-group and worst-group test accuracies on Waterbirds as a function of context size for the naive and proposed approaches with or without permuting input dimensions. Shaded regions show standard deviation across 5 training runs.
  • Figure 3: Majority-group and worst-group test accuracies on Waterbirds-severe as a function of context size for the naive and proposed approaches with or without permuting input dimensions. Shaded regions show standard deviation across 5 training runs.
  • Figure 4: Worst-group test accuracies on CelebA for the proposed approach and conventional methods such as 1-NN, ERM, and GroupDRO. Shaded regions show standard deviation across 5 training runs. The "+G" marker indicates setting $\tilde{y}_i$ to represent $g_i$, i.e., passing the groups of the context examples as input. See \ref{sec:multiple-tasks} for more information.
  • Figure 5: Worst-group test accuracies on Waterbirds and Waterbirds-severe for the proposed approach and conventional methods such as 1-NN, ERM, and GroupDRO. Majority-group accuracies are reported in \ref{fig:wb-and-mod-wb-maj-acc-baseline-comparison} of \ref{app:more-results}. Shaded regions show standard deviation across 5 training runs.
  • ...and 9 more figures
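
The Figure 1 caption says the attention mask is modified so that intermediate query tokens have no effect on other tokens. One way to read this is as a causal mask in which query positions are hidden from every token except themselves. The snippet below is a minimal sketch under that assumption, not the paper's implementation: the function name `build_attention_mask`, the one-token-per-item layout, and the `is_query` flag array are all hypothetical, and the caption's change to token positions is not modeled.

```python
import numpy as np


def build_attention_mask(is_query):
    """Return a boolean mask where mask[i, j] is True if token i may attend to token j.

    Assumption (not from the paper): the sequence is processed causally, and a
    query token may be attended to only by itself.
    """
    is_query = np.asarray(is_query, dtype=bool)
    n = is_query.shape[0]
    # Standard causal mask: token i sees tokens 0..i.
    causal = np.tril(np.ones((n, n), dtype=bool))
    # Column j is visible only if token j is not a query...
    visible = np.broadcast_to(~is_query, (n, n)).copy()
    # ...except that every token may always attend to itself.
    np.fill_diagonal(visible, True)
    return causal & visible


# Hypothetical layout: example 1, query 1, example 2, query 2.
is_query = [False, True, False, True]
mask = build_attention_mask(is_query)
print(mask.astype(int))
```

With this mask, removing or replacing a query token leaves the attention inputs of every other token unchanged, which is the property the caption describes; each query still sees all context examples that precede it.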