SALAD: Improving Robustness and Generalization through Contrastive Learning with Structure-Aware and LLM-Driven Augmented Data
Suyoung Bae, Hyojun Kim, YunSeok Choi, Jee-Hyong Lee
TL;DR
This paper tackles spurious correlations in NLP by introducing SALAD, a framework that jointly learns from structure-aware positive samples and LLM-generated counterfactual negative samples to improve robustness and generalization. SALAD constructs positives by masking non-causal structural tokens identified with part-of-speech (POS) tagging. It generates diverse negatives with an LLM guided by causal-word information, and it trains with a triplet loss alongside the standard supervised loss. Across sentiment classification, sexism detection, and natural language inference, SALAD shows strong robustness, notable gains in out-of-distribution and cross-domain settings, and in-domain performance competitive with strong baselines. The approach reduces reliance on shortcuts and improves generalization. In addition, the quality of GPT-generated counterfactually augmented data (CAD) closely matches that of human-generated CAD, highlighting SALAD's potential for building scalable, robust NLP systems.
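The TL;DR says SALAD trains with a triplet loss alongside standard supervision. The sketch below only illustrates that kind of combined objective and is not the authors' implementation. It assumes Euclidean distance between embeddings, a hinge margin `margin`, a weighting factor `lam`, and cross-entropy as the supervised loss. All of these, and the function names, are illustrative choices. It uses plain Python lists so it runs without a deep-learning framework.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    """Euclidean distance between two embedding vectors (assumed metric)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor: Sequence[float], positive: Sequence[float],
                 negative: Sequence[float], margin: float = 1.0) -> float:
    """Hinge triplet loss: pull the positive closer than the negative by `margin`."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Numerically stable softmax cross-entropy for a single example."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def combined_objective(logits: Sequence[float], label: int,
                       anchor: Sequence[float], positive: Sequence[float],
                       negative: Sequence[float],
                       lam: float = 1.0, margin: float = 1.0) -> float:
    """Hypothetical total loss: supervised term plus weighted triplet term."""
    return cross_entropy(logits, label) + lam * triplet_loss(anchor, positive, negative, margin)
```

Here the anchor would be the original sentence embedding, the positive a structure-aware masked variant, and the negative an LLM-generated counterfactual. When the negative is already farther from the anchor than the positive by at least `margin`, the triplet term is zero and only the supervised loss remains.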
Abstract
In various natural language processing (NLP) tasks, fine-tuning Pre-trained Language Models (PLMs) often causes models to rely on spurious correlations, which degrades performance, particularly on out-of-distribution data. To address this problem, we propose SALAD (Structure-Aware and LLM-driven Augmented Data), a novel approach that enhances model robustness and generalization by generating structure-aware and counterfactually augmented data for contrastive learning. Our method uses a tagging-based approach to generate structure-aware positive samples and uses large language models (LLMs) to generate counterfactual negative samples with diverse sentence patterns. Through contrastive learning, SALAD enables the model to focus on the structural relationships between key sentence components while reducing its reliance on spurious correlations. We validate our approach through experiments on three tasks: Sentiment Classification, Sexism Detection, and Natural Language Inference. The results demonstrate that SALAD not only improves model robustness and performance across different environments but also enhances generalization to out-of-distribution datasets and cross-domain scenarios.
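The abstract describes a tagging-based procedure for building structure-aware positive samples. The following is one plausible reading, not the paper's exact algorithm. It assumes the sentence is already POS-tagged with Universal POS tags, that the set of causal words is given, and that non-causal tokens whose tag falls in a hypothetical `MASKABLE_TAGS` set are replaced with `[MASK]`. The tag set, the `[MASK]` token, and the function name are illustrative.

```python
from typing import List, Set, Tuple

# Hypothetical mask token and tag set; the paper's actual choices may differ.
MASK = "[MASK]"
MASKABLE_TAGS: Set[str] = {"DET", "ADP", "PRON", "AUX", "CCONJ"}


def structure_aware_positive(
    tagged: List[Tuple[str, str]],
    causal_words: Set[str],
    maskable_tags: Set[str] = MASKABLE_TAGS,
) -> List[str]:
    """Mask non-causal tokens whose POS tag is in `maskable_tags`.

    Causal words (compared case-insensitively) are always kept, so the
    label-relevant signal survives in the positive sample.
    """
    out: List[str] = []
    for token, tag in tagged:
        if token.lower() in causal_words:
            out.append(token)   # keep causal words untouched
        elif tag in maskable_tags:
            out.append(MASK)    # mask non-causal tokens with a maskable tag
        else:
            out.append(token)   # keep all other tokens as-is
    return out


# Usage: a pre-tagged sentence where "great" is the assumed causal word.
tagged = [("The", "DET"), ("movie", "NOUN"), ("was", "AUX"), ("great", "ADJ")]
print(structure_aware_positive(tagged, {"great"}))
```

In this example "The" and "was" are masked and "movie" and "great" are kept, producing a positive that preserves the causal word while perturbing surrounding structure.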
