
Improving Large Language Model Safety with Contrastive Representation Learning

Samuel Simko, Mrinmaya Sachan, Bernhard Schölkopf, Zhijing Jin

TL;DR

This work tackles the safety challenges that jailbreak and adversarial prompts pose for large language models by presenting a contrastive representation learning defense. The method formulates safety in a learned representation space and optimizes a triplet-based objective with adversarial hard negative mining, which pulls benign representations together and pushes them away from harmful ones. At the same time, it preserves benign behavior through KL-based alignment on safe prompts. Empirical results show that the triplet defense outperforms circuit breakers and RepBend against both input-space and embedding-space attacks, driving embedding-space ASR as low as 0% (and 4.88% with adversarial mining) on Llama 3 8B while retaining general-language performance on standard benchmarks. The approach generalizes to out-of-distribution formats (measured by MMDR) and remains effective across multiple models, albeit with compute costs and some model-specific limitations, which suggests practical value for deploying safer LLMs in diverse settings.
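The TL;DR pairs the representation-level triplet objective with KL alignment on safe prompts. The sketch below shows one minimal way such a combined objective could be assembled. It assumes the KL term compares the original and finetuned models' next-token distributions on a benign prompt as KL(original ‖ finetuned), and that it is added to the triplet term with a hypothetical weight `kl_weight`. The function names are illustrative, and the paper's exact KL direction, weighting, and averaging over tokens may differ.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two categorical distributions given as probability lists.

    Assumes q is strictly positive wherever p is (true for softmax outputs).
    Terms with p_i == 0 contribute nothing, by the usual 0 * log 0 = 0 convention.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def combined_loss(triplet_term, p_orig_benign, p_new_benign, kl_weight=1.0):
    """Hypothetical total objective: a precomputed triplet loss on representations
    plus a KL penalty keeping the finetuned model's output distribution on a
    benign prompt close to the original model's."""
    return triplet_term + kl_weight * kl_divergence(p_orig_benign, p_new_benign)


# Toy usage: the penalty is zero when the finetuned model's benign-prompt
# distribution matches the original, and grows as the two drift apart.
print(combined_loss(0.3, [0.5, 0.5], [0.5, 0.5]))
print(combined_loss(0.3, [0.5, 0.5], [0.9, 0.1]))
```

Keeping the KL penalty separate from the triplet term makes the trade-off explicit: a larger `kl_weight` favors preserving benign behavior over further separating representations.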

Abstract

Large Language Models (LLMs) are powerful tools with profound societal impacts, yet their ability to generate responses to diverse and uncontrolled inputs leaves them vulnerable to adversarial attacks. While existing defenses often struggle to generalize across varying attack types, recent advancements in representation engineering offer promising alternatives. In this work, we propose a defense framework that formulates model defense as a contrastive representation learning (CRL) problem. Our method finetunes a model using a triplet-based loss combined with adversarial hard negative mining to encourage separation between benign and harmful representations. Our experimental results across multiple models demonstrate that our approach outperforms prior representation engineering-based defenses, improving robustness against both input-level and embedding-space attacks without compromising standard performance. Our code is available at https://github.com/samuelsimko/crl-llm-defense
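To make "adversarial hard negative mining" concrete, here is a toy sketch of one common way to mine hard negatives: a projected-gradient attack that moves a negative (harmful) representation as close as possible to a benign anchor within an L2 ball of radius `eps`. The mined negative is then harder for the triplet loss to separate. This is a simplification under stated assumptions: the perturbation acts directly on a representation vector, whereas a real implementation would presumably attack the model's input embeddings and read off the resulting hidden states. Euclidean distance, the margin, and the names `mine_hard_negative`, `eps`, `step`, and `n_steps` are all hypothetical.

```python
import math


def l2(x, y):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Standard margin triplet loss: max(0, d(a, p) - d(a, n) + margin)."""
    return max(0.0, l2(anchor, positive) - l2(anchor, negative) + margin)


def mine_hard_negative(anchor, negative, eps=0.5, step=0.1, n_steps=20):
    """Toy PGD-style attack (hypothetical helper): perturb `negative` within an
    L2 ball of radius `eps` so that it moves toward `anchor`."""
    delta = [0.0] * len(negative)
    for _ in range(n_steps):
        current = [n + d for n, d in zip(negative, delta)]
        dist = l2(current, anchor)
        if dist == 0.0:
            break  # already on top of the anchor; nothing left to minimize
        # The gradient of ||current - anchor|| w.r.t. delta is (current - anchor) / dist;
        # stepping against it shrinks the distance to the anchor.
        delta = [d - step * (c - a) / dist for d, c, a in zip(delta, current, anchor)]
        # Project delta back onto the L2 ball of radius eps.
        norm = math.sqrt(sum(d * d for d in delta))
        if norm > eps:
            delta = [d * eps / norm for d in delta]
    return [n + d for n, d in zip(negative, delta)]


anchor, positive, negative = [0.0, 0.0], [1.0, 0.0], [2.0, 0.0]
hard = mine_hard_negative(anchor, negative, eps=0.5)
print(triplet_loss(anchor, positive, negative))  # inactive triplet: loss is 0
print(triplet_loss(anchor, positive, hard))      # mined negative reactivates the loss
```

The point of the example is the last two lines: an easy negative can already satisfy the margin and contribute no gradient, while its adversarially mined counterpart lies closer to the anchor and still produces a training signal.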


Paper Structure

This paper contains 66 sections, 2 theorems, 16 equations, 5 figures, 21 tables.

Key Result

Theorem 1

The circuit breakers loss $\mathcal{L}_{CB}$ can be rewritten as a triplet loss $\mathcal{L}_{triplet}$ with null distances $d_0(x, y) = 0$.
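As a hedged illustration of how such a reduction can work, assume the circuit breakers loss has the form of Zou et al. (2024) without its coefficient schedule. Let $h$ denote a representation from the original model, $h'$ the corresponding representation from the finetuned model, $x_b$ a benign prompt, and $x_h$ a harmful prompt. Also assume a margin-based triplet loss whose distances $d^+, d^-$ and margin $m$ may be chosen per triplet. The distance choices and triplet assignments below are one possible instantiation, not necessarily the paper's proof. Because $d_0$ ignores its arguments, the unused member of each triplet is written as $\cdot$:

```latex
\begin{aligned}
\mathcal{L}_{CB} &= \lVert h'(x_b) - h(x_b) \rVert_2 + \mathrm{ReLU}\big(\cos(h'(x_h), h(x_h))\big) \\
\mathcal{L}_{triplet}(a, p, n) &= \mathrm{ReLU}\big(d^+(a, p) - d^-(a, n) + m\big) \\[4pt]
&\text{Benign triplet: } a = h(x_b),\ p = h'(x_b),\ n = \cdot,\quad d^+(x, y) = \lVert x - y \rVert_2,\ d^- = d_0,\ m = 0: \\
\mathcal{L}_{triplet} &= \mathrm{ReLU}\big(\lVert h'(x_b) - h(x_b) \rVert_2 - 0 + 0\big) = \lVert h'(x_b) - h(x_b) \rVert_2 \\[4pt]
&\text{Harmful triplet: } a = h(x_h),\ p = \cdot,\ n = h'(x_h),\quad d^+ = d_0,\ d^-(x, y) = 1 - \cos(x, y),\ m = 1: \\
\mathcal{L}_{triplet} &= \mathrm{ReLU}\big(0 - (1 - \cos(h(x_h), h'(x_h))) + 1\big) = \mathrm{ReLU}\big(\cos(h'(x_h), h(x_h))\big)
\end{aligned}
```

Summing the two triplet terms recovers $\mathcal{L}_{CB}$. In this reading, the null distances are exactly what leave circuit breaking "without additional constraints": replacing $d_0$ with genuine distances adds the pull-together and push-apart terms that the Triplet defense introduces.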

Figures (5)

  • Figure 1: Comparison of the Triplet defense with the Circuit Breaking defense. Unlike other adversarial defense methods, which train models to refuse harmful tasks, circuit breaking aims to break generation at harmful content. It fine-tunes models to keep learned harmless states (or representations) close together while separating newly learned harmful states from their original counterparts, without additional constraints. In contrast, the Triplet defense additionally pulls learned harmful states together and pushes them away from learned harmless states, which increases contrast and robustness to embedding-space attacks.
  • Figure 2: Embedding attack success rate (ASR) using StrongREJECT for various defenses (Llama 3 8B Instruct).
  • Figure 3: Triplet loss objective before and after a learning step. The anchor (blue) is kept at the same position, while the positive (green) is moved closer to the anchor, and the negative (red) is moved further away from the anchor.
  • Figure 4: Examples of prompts and responses for input-space and embedding-space attacks in existing adversarial training-based defenses (left) and our triplet-based defense (right).
  • Figure 5: t-SNE visualization of layer 25 representations of Llama 3 8B. Representations from benign prompts (green), harmful prompts (red), and embedding-attacked harmful prompts (brown) are shown. Our Triplet-based defense achieves tighter clustering of harmful and attacked representations compared to baseline defenses.
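The learning step depicted in Figure 3 can be reproduced in a toy 2-D setting. The sketch below assumes Euclidean distance, a margin of 1.0, and plain gradient descent applied directly to the positive and negative points while the anchor is held fixed. The paper's actual distances, margin, and optimizer may differ, and the function names are hypothetical.

```python
import math


def l2(x, y):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_loss(anchor, positive, negative, margin=1.0):
    """Standard margin triplet loss: max(0, d(a, p) - d(a, n) + margin)."""
    return max(0.0, l2(anchor, positive) - l2(anchor, negative) + margin)


def triplet_step(anchor, positive, negative, margin=1.0, lr=0.1):
    """One gradient-descent step on the positive and negative with the anchor
    fixed, as in Figure 3. Returns the updated (positive, negative)."""
    if triplet_loss(anchor, positive, negative, margin) == 0.0:
        return list(positive), list(negative)  # margin satisfied: zero gradient
    d_ap = l2(anchor, positive)
    d_an = l2(anchor, negative)
    new_pos, new_neg = list(positive), list(negative)
    if d_ap > 0.0:
        # d/dp ||p - a|| = (p - a) / ||p - a||; descending moves p toward a.
        new_pos = [p - lr * (p - a) / d_ap for p, a in zip(positive, anchor)]
    if d_an > 0.0:
        # The loss contains -||n - a||, so descending moves n away from a.
        new_neg = [n + lr * (n - a) / d_an for n, a in zip(negative, anchor)]
    return new_pos, new_neg


anchor, positive, negative = [0.0, 0.0], [1.0, 1.0], [1.0, 0.0]
new_pos, new_neg = triplet_step(anchor, positive, negative)
print(l2(anchor, positive), "->", l2(anchor, new_pos))  # positive moves closer
print(l2(anchor, negative), "->", l2(anchor, new_neg))  # negative moves away
```

Once the negative is at least `margin` farther from the anchor than the positive, the loss is zero and the points stop moving. This is why hard negatives, which sit close to the anchor, are the ones that keep supplying gradient.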

Theorems & Definitions (4)

  • Theorem 1
  • proof
  • Theorem 2
  • proof