
SALAD: Improving Robustness and Generalization through Contrastive Learning with Structure-Aware and LLM-Driven Augmented Data

Suyoung Bae, Hyojun Kim, YunSeok Choi, Jee-Hyong Lee

TL;DR

This paper tackles spurious correlations in NLP by introducing SALAD, a framework that jointly learns from structure-aware positive samples and LLM-generated counterfactual negatives to improve robustness and generalization. SALAD constructs positives by masking non-causal structural tokens identified through POS tagging, generates diverse negatives with an LLM guided by causal-word information, and trains with a triplet loss alongside standard supervised learning. Across sentiment classification, sexism detection, and natural language inference, SALAD shows strong robustness, notable improvements in out-of-distribution and cross-domain settings, and in-domain performance competitive with strong baselines. The approach reduces shortcut reliance, and the quality of GPT-generated counterfactually augmented data (CAD) closely matches that of human-generated data, suggesting SALAD's potential for scalable, robust NLP systems.
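The positive-sample step can be illustrated with a minimal sketch. It assumes the sentence has already been POS-tagged and that the non-causal tag set $G$ is given. The `[MASK]` replacement token, the function name `make_positive`, the hand-assigned tags, and the example contents of `G` are all hypothetical: this summary does not state which token SALAD substitutes or which tagger it uses.

```python
from typing import List, Set, Tuple

# Hypothetical replacement token; SALAD's actual choice is not stated in this summary.
MASK = "[MASK]"


def make_positive(tagged: List[Tuple[str, str]], non_causal_tags: Set[str]) -> List[str]:
    """Build a structure-aware positive by replacing every token whose POS tag
    is in the non-causal tag set G with MASK; all other tokens are kept as-is."""
    return [MASK if tag in non_causal_tags else tok for tok, tag in tagged]


# Example input: tokens with Penn Treebank tags assigned by hand (assumed, not tagger output).
sentence = [("The", "DT"), ("movie", "NN"), ("was", "VBD"),
            ("truly", "RB"), ("wonderful", "JJ"), (".", ".")]
G = {"DT", "."}  # hypothetical non-causal tag set, for illustration only

positive = make_positive(sentence, G)
print(" ".join(positive))  # [MASK] movie was truly wonderful [MASK]
```

Because only tags in $G$ are masked, the sentiment-bearing words ("wonderful") survive, so the positive keeps the label-relevant content while its surface structure changes.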

Abstract

In various natural language processing (NLP) tasks, fine-tuning Pre-trained Language Models (PLMs) often leads models to rely on spurious correlations, which degrades performance, particularly on out-of-distribution data. To address this problem, we propose SALAD (Structure-Aware and LLM-driven Augmented Data), a novel approach that enhances model robustness and generalization by generating structure-aware and counterfactually augmented data for contrastive learning. Our method uses a tagging-based approach to generate structure-aware positive samples and large language models (LLMs) to generate counterfactual negative samples with diverse sentence patterns. Through contrastive learning, SALAD enables the model to focus on the structural relationships between key sentence components while minimizing reliance on spurious correlations. We validate our approach through experiments on three tasks: Sentiment Classification, Sexism Detection, and Natural Language Inference. The results demonstrate that SALAD not only improves model robustness and performance across different environments but also enhances generalization to out-of-distribution datasets and cross-domain scenarios.
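The training objective, a triplet loss added to standard supervised training, can be sketched as follows. This is a minimal illustration only. It assumes the anchor is the embedding of the original sentence, the positive is its structure-aware masked version, and the negative is an LLM-generated counterfactual. The Euclidean distance, the margin of 1.0, and the weighting factor `lam` are hypothetical choices, not values taken from the paper.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between two embeddings (assumed distance function)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float = 1.0) -> float:
    """Hinge on the distance gap: zero once the negative is at least `margin`
    farther from the anchor than the positive is."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Standard classification loss, computed with a numerically stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


def total_loss(logits: Sequence[float], label: int, anchor: Sequence[float],
               positive: Sequence[float], negative: Sequence[float],
               lam: float = 0.5, margin: float = 1.0) -> float:
    """Supervised loss plus a weighted triplet term (lam is a hypothetical weight)."""
    return cross_entropy(logits, label) + lam * triplet_loss(anchor, positive, negative, margin)


# Usage with made-up numbers: classifier logits and 2-d embeddings for one triplet.
loss = total_loss([2.0, 0.5], label=0,
                  anchor=[0.1, 0.9], positive=[0.2, 0.8], negative=[0.9, 0.1])
print(round(loss, 4))
```

The hinge form means a triplet stops contributing once the counterfactual is pushed far enough away. From that point on, only the supervised term drives learning for that example.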

Paper Structure

This paper contains 40 sections, 5 equations, 4 figures, 13 tables.

Figures (4)

  • Figure 1: Overview of SALAD. Our proposed method consists of three steps. First, we use a tagging-based method to generate positive data based on the structure where shortcuts occur (Sec. 3.1). Next, we use an LLM to generate counterfactual data that captures complex and diverse sentence patterns (Sec. 3.2). Finally, contrastive learning is applied to capture key sentence structural patterns across our augmented data, minimizing spurious correlations and enhancing generalization performance (Sec. 3.3).
  • Figure 2: Experiments on choosing $k$: $k = 8$ yields a significant performance improvement on the CF-IMDB dataset, particularly on the out-of-distribution (OOD) dataset.
  • Figure 3: Accuracy reduction of each POS category across datasets: The $x$-axis represents each POS category, and the $y$-axis represents the average accuracy reduction. We define POS tags with an average accuracy reduction of less than 1% as the non-causal tag set $G$.
  • Figure 4: Performance variations of SALAD on datasets generated with each instruction. The number following "SALAD" corresponds to the number of the instruction in the paper's table of instructions.
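The rule in Figure 3 for defining the non-causal tag set $G$ can be expressed as a small filter. This sketch assumes the per-tag average accuracy reductions, in percentage points as on the $y$-axis, have already been measured. The function name `select_non_causal_tags` and the example numbers are hypothetical, and they are not the values plotted in Figure 3.

```python
from typing import Dict, Set


def select_non_causal_tags(avg_acc_reduction: Dict[str, float],
                           threshold: float = 1.0) -> Set[str]:
    """Return the POS tags whose average accuracy reduction is strictly below
    `threshold` (percentage points); these form the non-causal tag set G."""
    return {tag for tag, drop in avg_acc_reduction.items() if drop < threshold}


# Hypothetical per-tag reductions for illustration only (NOT the Figure 3 values).
drops = {"NN": 4.2, "JJ": 6.8, "DT": 0.3, "IN": 0.7, "RB": 1.5}
G = select_non_causal_tags(drops)
print(sorted(G))  # ['DT', 'IN']
```

The strict `<` follows the caption's wording ("less than 1%"). A tag with a reduction of exactly 1.0 would therefore be treated as causal and left unmasked.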