Table of Contents
Fetching ...

Adversarial Reinforcement Learning for Large Language Model Agent Safety

Zizhao Wang, Dingcheng Li, Vaishakh Keshava, Phillip Wallis, Ananth Balashankar, Peter Stone, Lukas Rutishauser

TL;DR

This work proposes Adversarial Reinforcement Learning for Agent Safety (ARLAS), a novel framework that leverages adversarial reinforcement learning (RL) by formulating the problem as a two-player zero-sum game and confirms that the adversarial process generates a diverse and challenging set of attacks, leading to a more robust agent compared to the base model.

Abstract

Large Language Model (LLM) agents can leverage tools such as Google Search to complete complex tasks. However, this tool usage introduces the risk of indirect prompt injections, where malicious instructions hidden in tool outputs can manipulate the agent, posing security risks like data leakage. Current defense strategies typically rely on fine-tuning LLM agents on datasets of known attacks. However, the generation of these datasets relies on manually crafted attack patterns, which limits their diversity and leaves agents vulnerable to novel prompt injections. To address this limitation, we propose Adversarial Reinforcement Learning for Agent Safety (ARLAS), a novel framework that leverages adversarial reinforcement learning (RL) by formulating the problem as a two-player zero-sum game. ARLAS co-trains two LLMs: an attacker that learns to autonomously generate diverse prompt injections and an agent that learns to defend against them while completing its assigned tasks. To ensure robustness against a wide range of attacks and to prevent cyclic learning, we employ a population-based learning framework that trains the agent to defend against all previous attacker checkpoints. Evaluated on BrowserGym and AgentDojo, agents fine-tuned with ARLAS achieve a significantly lower attack success rate than the original model while also improving their task success rate. Our analysis further confirms that the adversarial process generates a diverse and challenging set of attacks, leading to a more robust agent compared to the base model.

Adversarial Reinforcement Learning for Large Language Model Agent Safety

TL;DR

This work proposes Adversarial Reinforcement Learning for Agent Safety (ARLAS), a novel framework that leverages adversarial reinforcement learning (RL) by formulating the problem as a two-player zero-sum game and confirms that the adversarial process generates a diverse and challenging set of attacks, leading to a more robust agent compared to the base model.

Abstract

Large Language Model (LLM) agents can leverage tools such as Google Search to complete complex tasks. However, this tool usage introduces the risk of indirect prompt injections, where malicious instructions hidden in tool outputs can manipulate the agent, posing security risks like data leakage. Current defense strategies typically rely on fine-tuning LLM agents on datasets of known attacks. However, the generation of these datasets relies on manually crafted attack patterns, which limits their diversity and leaves agents vulnerable to novel prompt injections. To address this limitation, we propose Adversarial Reinforcement Learning for Agent Safety (ARLAS), a novel framework that leverages adversarial reinforcement learning (RL) by formulating the problem as a two-player zero-sum game. ARLAS co-trains two LLMs: an attacker that learns to autonomously generate diverse prompt injections and an agent that learns to defend against them while completing its assigned tasks. To ensure robustness against a wide range of attacks and to prevent cyclic learning, we employ a population-based learning framework that trains the agent to defend against all previous attacker checkpoints. Evaluated on BrowserGym and AgentDojo, agents fine-tuned with ARLAS achieve a significantly lower attack success rate than the original model while also improving their task success rate. Our analysis further confirms that the adversarial process generates a diverse and challenging set of attacks, leading to a more robust agent compared to the base model.

Paper Structure

This paper contains 24 sections, 6 equations, 7 figures, 3 tables, 1 algorithm.

Figures (7)

  • Figure 1: ARLAS enhances LLM agent safety via a jointly trained attacker. In each turn of an episode, the attacker first generates an indirect prompt injection to insert into the observation, and then the agent selects an action (i.e., which tool to call and its parameters). The agent and the attacker receive sparse rewards at the end of the episode, based on whether the attacker tricks the agent into leaking user information and whether the agent successfully completes the task. ARLAS trains both models to maximize their respective rewards using RL.
  • Figure 2: Compare to (left) iterative training that could lead to cyclic learning, (right) ARLAS leverages population-based learning, training the agent model to be robust against all previous attacker models.
  • Figure 3: ARLAS performance on unseen BrowserGym tasks, measured as the mean and standard error across 3 random seeds. Each heat map shows how the agent at different learning stages performs against the attacker at different stages, where the top row measures the performance when there is no attack.
  • Figure 4: All methods' performance on unseen BrowserGym tasks, measured as the mean and standard error across 3 random seeds. Each heat map shows how the agent from each method performs against the attacker from each method, where the top row measures the performance when there is no attack and the bottom row shows the average performance of each agent playing against all attackers. In each heat map, the agents are ordered based on average performance, from the worst to the best.
  • Figure 5: (Left) UMAP projection of attacks generated by ARLAS at different learning stages. (Right) Average pairwise distance across all tasks at different ARLAS learning stages.
  • ...and 2 more figures