WPN: An Unlearning Method Based on N-pair Contrastive Learning in Language Models

Guitao Chen, Yunshen Wang, Hongye Sun, Guang Chen

TL;DR

This work tackles the problem of harmful outputs in pretrained language models by introducing Weighted Positional N-pair (WPN) Learning, a targeted unlearning approach that does not retrain the entire model. It combines position-weighted mean pooling with an N-pair contrastive loss to reshape the output distribution toward harmless responses while preserving general capabilities, avoiding the degradation typically caused by gradient-ascent methods. Experiments on OPT and GPT-NEO show that WPN achieves high harmlessness rates (up to 95.8% in some settings) with less than 2% average degradation on nine NLP benchmarks, and that it improves generalizability and robustness to adversarial prompts. The results suggest WPN as a practical, plug-in unlearning technique for reducing harmful content without sacrificing broad language modeling performance, with favorable time-cost characteristics for larger models.

Abstract

Generative language models (LMs) offer numerous advantages but may produce inappropriate or harmful outputs due to the harmful knowledge acquired during pre-training. This knowledge often manifests as undesirable correspondences, such as "harmful prompts" leading to "harmful outputs," which our research aims to mitigate through unlearning techniques. However, existing unlearning methods based on gradient ascent can significantly impair the performance of LMs. To address this issue, we propose a novel approach called Weighted Positional N-pair (WPN) Learning, which leverages position-weighted mean pooling within an n-pair contrastive learning framework. WPN is designed to modify the output distribution of LMs by eliminating specific harmful outputs (e.g., replacing toxic responses with neutral ones), thereby transforming the model's behavior from "harmful prompt-harmful output" to "harmful prompt-harmless response". Experiments on OPT and GPT-NEO LMs show that WPN effectively reduces the proportion of harmful responses, achieving a harmless rate of up to 95.8% while maintaining stable performance on nine common benchmarks (with less than 2% degradation on average). Moreover, we provide empirical evidence to demonstrate WPN's ability to weaken the harmful correspondences in terms of generalizability and robustness, as evaluated on out-of-distribution test sets and under adversarial attacks.
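
The two ingredients named in the abstract can be illustrated with a minimal sketch. It is not the paper's implementation: it assumes linearly increasing position weights (token t of T gets weight t / (1 + 2 + ... + T)), which is one common form of position-weighted mean pooling, and it uses the standard N-pair loss, log(1 + sum_i exp(a·n_i - a·p)). The function names and the toy vectors are hypothetical, and the paper's own equations may differ in detail.

```python
import math


def position_weighted_mean_pool(hidden_states):
    """Pool a list of token vectors into a single vector.

    Assumption: token t (1-indexed) gets weight t / (1 + 2 + ... + T),
    so later tokens contribute more than earlier ones.
    """
    seq_len = len(hidden_states)
    dim = len(hidden_states[0])
    weight_total = seq_len * (seq_len + 1) / 2
    pooled = [0.0] * dim
    for t, vec in enumerate(hidden_states, start=1):
        weight = t / weight_total
        for j in range(dim):
            pooled[j] += weight * vec[j]
    return pooled


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def n_pair_loss(anchor, positive, negatives):
    """Standard N-pair loss: log(1 + sum_i exp(a.n_i - a.p)).

    The loss shrinks as the anchor moves toward the positive and
    away from every negative.
    """
    pos_score = dot(anchor, positive)
    return math.log1p(
        sum(math.exp(dot(anchor, neg) - pos_score) for neg in negatives)
    )


# Toy usage with made-up 2-d vectors (illustrative values only).
anchor = position_weighted_mean_pool([[1.0, 0.0], [0.0, 1.0]])
harmless = [0.0, 1.0]   # stands in for a harmless-response embedding
harmful = [[1.0, 0.0]]  # stands in for harmful-response embeddings
loss = n_pair_loss(anchor, harmless, harmful)
```

In this sketch the anchor stands in for the pooled representation the model produces for a harmful prompt, the positive for a harmless response, and the negatives for harmful responses. That role assignment is an assumption inferred from the abstract, not a detail it states.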

Paper Structure

This paper contains 27 sections, 7 equations, 4 figures, 4 tables.

Figures (4)

  • Figure 1: The average PPL of responses from different LMs on 1500 samples. After executing WPN, the PPL of LMs can be maintained at a low level.
  • Figure 2: Results of the N-pair loss (Equation (\ref{npair-loss})) using three pooling methods for six LMs. $PH1$, $PH2$, and $PH3$ (written $PH_{1}$, $PH_{2}$, and $PH_{3}$ in the formulas) respectively represent the harmless response rates on validation sets $\mathcal{D}_{dev1}$, $\mathcal{D}_{dev2}$, and $\mathcal{D}_{dev3}$ after the LMs execute the WPN method. $AVG.$ (written $A_{avg}$ in the formulas) represents the average score on nine NLP benchmarks. $PA1=\alpha PH_{1}+\beta A_{avg}$ and $PA2=\alpha PH_{3}+\beta A_{avg}$ respectively represent the comprehensive performance and generalization performance of the unlearning algorithm, where $\alpha=0.2$, $\beta=0.8$.
  • Figure 3: Comparison of the execution times between two unlearning algorithms. The experiment was conducted with a total of 500 data points trained over 3 epochs.
  • Figure 4: Text generation example. When harmful questions pass through the unlearned LM, the text representation before decoding tends to lean towards the harmless text space.
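
The composite scores defined in the Figure 2 caption are plain weighted sums, which the short sketch below spells out. The weights 0.2 and 0.8 come from the caption. The function name and the input numbers are hypothetical placeholders, not results from the paper.

```python
def composite_score(harmless_rate, benchmark_avg, alpha=0.2, beta=0.8):
    """Weighted sum alpha * PH + beta * A_avg, as in the Figure 2 caption.

    harmless_rate: harmless response rate on one validation set (PH).
    benchmark_avg: average score over the NLP benchmarks (A_avg).
    """
    return alpha * harmless_rate + beta * benchmark_avg


# Made-up inputs for illustration only (not values reported in the paper).
ph1, ph3, a_avg = 90.0, 80.0, 50.0
pa1 = composite_score(ph1, a_avg)  # comprehensive performance (uses PH1)
pa2 = composite_score(ph3, a_avg)  # generalization performance (uses PH3)
```

With beta four times alpha, the benchmark average carries most of the score, so a method that gains harmlessness at a large cost in general capability is penalized.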