
DAGER: Exact Gradient Inversion for Large Language Models

Ivo Petrov, Dimitar I. Dimitrov, Maximilian Baader, Mark Niklas Müller, Martin Vechev

TL;DR

This work proposes DAGER, the first algorithm to recover whole batches of input text exactly, and provides an efficient GPU implementation that beats prior attacks in speed, scalability, and reconstruction quality.

Abstract

Federated learning works by aggregating locally computed gradients from multiple clients, thus enabling collaborative training without sharing private client data. However, prior work has shown that the data can actually be recovered by the server using so-called gradient inversion attacks. While these attacks perform well when applied to images, they are limited in the text domain and only permit approximate reconstruction of small batches and short input sequences. In this work, we propose DAGER, the first algorithm to recover whole batches of input text exactly. DAGER leverages the low-rank structure of self-attention layer gradients and the discrete nature of token embeddings to efficiently check whether a given token sequence is part of the client data. We use this check to exactly recover full batches in the honest-but-curious setting, without any prior on the data, for both encoder- and decoder-based architectures, using exhaustive heuristic search and a greedy approach, respectively. We provide an efficient GPU implementation of DAGER and show experimentally that it recovers full batches of size up to 128 on large language models (LLMs), beating prior attacks in speed (20x at the same batch size), scalability (10x larger batches), and reconstruction quality (ROUGE-1/2 > 0.99).
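The membership check described above can be illustrated with a toy span test. The sketch below assumes a linear layer Z = XW whose weight gradient factors as X^T (dL/dZ), and assumes dL/dZ has full row rank, so the column span of the observed gradient equals the span of the client's input embeddings. A vocabulary token is then accepted exactly when appending its embedding to the gradient does not raise the rank. The vocabulary, embeddings, and upstream gradient are hypothetical; the sketch uses exact rational arithmetic and ignores positional information, whereas a real attack on floating-point gradients would need a numerical tolerance instead of an exact rank comparison.

```python
from fractions import Fraction

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def rank(M):
    """Exact rank by Gaussian elimination over the rationals."""
    rows = [[Fraction(x) for x in row] for row in M]
    r = 0
    for c in range(len(rows[0])):
        pivot = next((i for i in range(r, len(rows)) if rows[i][c] != 0), None)
        if pivot is None:
            continue
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][c] != 0:
                f = rows[i][c] / rows[r][c]
                rows[i] = [a - f * p for a, p in zip(rows[i], rows[r])]
        r += 1
    return r

# Hypothetical toy vocabulary: token -> embedding in R^4.
vocab = {
    "the": [1, 0, 0, 1],
    "cat": [0, 1, 1, 0],
    "sat": [1, 1, 0, 0],
    "mat": [0, 0, 1, 1],
    "dog": [2, 0, 1, 0],
}

# Client input: embeddings of "the" and "cat" (b = 2 rows of X).
X = [vocab["the"], vocab["cat"]]
# Assumed upstream gradient dL/dZ (b x m) with full row rank b.
dL_dZ = [[1, 0, 2],
         [0, 1, 1]]
# Gradient observed by the server: dL/dW = X^T dL/dZ (n x m).
G = matmul([list(col) for col in zip(*X)], dL_dZ)

def in_span(G, v):
    """True iff v lies in the column span of G (appending v keeps the rank)."""
    augmented = [row + [x] for row, x in zip(G, v)]
    return rank(augmented) == rank(G)

recovered = [tok for tok, emb in vocab.items() if in_span(G, emb)]
print(recovered)  # ['the', 'cat']
```

The test scales with vocabulary size but each check is independent, which is what makes an exhaustive pass over the vocabulary feasible in this toy setting.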

Paper Structure

This paper contains 59 sections, 9 theorems, 10 equations, 7 figures, 11 tables, and 3 algorithms.

Key Result

Theorem 3.1

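The network's gradient w.r.t. the weights ${\bm{W}}$ of a linear layer ${\bm{Z}} = {\bm{X}}{\bm{W}}$, with input ${\bm{X}} \in \mathbb{R}^{b \times n}$ and weights ${\bm{W}} \in \mathbb{R}^{n \times m}$, can be represented as the matrix product
$$\frac{\partial \mathcal{L}}{\partial {\bm{W}}} = {\bm{X}}^\top \frac{\partial \mathcal{L}}{\partial {\bm{Z}}}.$$
Further, when the batch size $b \leq n, m$, the rank of $\frac{\partial \mathcal{L}}{\partial {\bm{W}}}$ is at most $b$.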
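To make Theorem 3.1 concrete, the following sketch checks the factorization and the rank bound on a toy linear layer. It assumes the row-vector convention Z = XW; the shapes (b = 2, n = 4, m = 3), the matrices, and the quadratic toy loss L = 0.5 * ||XW||_F^2 (for which dL/dZ = Z) are illustrative assumptions, not taken from the paper. Because the loss is quadratic in W, central finite differences give the exact derivative, so comparing them with X^T (dL/dZ) verifies the formula without floating-point noise.

```python
from fractions import Fraction

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def rank(M):
    """Exact rank by Gaussian elimination over the rationals."""
    rows = [[Fraction(x) for x in row] for row in M]
    r = 0
    for c in range(len(rows[0])):
        pivot = next((i for i in range(r, len(rows)) if rows[i][c] != 0), None)
        if pivot is None:
            continue
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][c] != 0:
                f = rows[i][c] / rows[r][c]
                rows[i] = [a - f * p for a, p in zip(rows[i], rows[r])]
        r += 1
    return r

# Toy layer Z = X W (assumed shapes): X is b x n, W is n x m.
X = [[1, 2, 0, 1],
     [0, 1, 3, 1]]
W = [[1, 0, 2],
     [0, 1, 1],
     [1, 1, 0],
     [2, 0, 1]]
b = len(X)

def loss(W):
    """Toy loss L = 0.5 * ||X W||_F^2, so that dL/dZ = Z."""
    Z = matmul(X, W)
    return Fraction(1, 2) * sum(z * z for row in Z for z in row)

# Closed form from the theorem: dL/dW = X^T dL/dZ, with dL/dZ = Z here.
grad_formula = matmul(transpose(X), matmul(X, W))

# Central differences are exact for a quadratic loss.
h = Fraction(1)
grad_fd = []
for i in range(len(W)):
    fd_row = []
    for j in range(len(W[0])):
        W_plus = [list(r) for r in W]
        W_minus = [list(r) for r in W]
        W_plus[i][j] += h
        W_minus[i][j] -= h
        fd_row.append((loss(W_plus) - loss(W_minus)) / (2 * h))
    grad_fd.append(fd_row)

print(grad_formula == grad_fd, rank(grad_formula), "<=", b)  # prints: True 2 <= 2
```

The 4 x 3 gradient has rank 2 even though it has more rows and columns than that; the rank is capped by the number of input rows b, as the theorem states.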

Figures (7)

  • Figure 1: Overview of DAGER. DAGER first recovers the sets of client tokens $\mathcal{T}^*_i$ at each position $i\in\mathcal{P}$ by testing each token in the vocabulary $\mathcal{V}$ via a span check based on the client gradients of the first self-attention. Then it recursively combines them into partial client sequences $\mathcal{S}_i$ with length up to $i$, filtered to obtain the correct sequences $\mathcal{S}^*_i$ via the gradients of the second self-attention.
  • Figure 2: Effect of L1 and L2 Filtering
  • Figure 3: Encoder Ablation Study
  • Figure: Recovering Individual Tokens
  • ...and 2 more figures
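Figure 1 describes combining the per-position token sets $\mathcal{T}^*_i$ into partial sequences $\mathcal{S}_i$ and filtering them with the second self-attention layer. Below is a minimal sketch of that recursion, written as greedy prefix extension (the approach the abstract associates with decoder-based models). `passes_check` is a hypothetical stand-in for the second-layer span check, and the toy oracle `ideal_check` simply accepts prefixes of a known client batch; it does not model gradients, noise, or false positives.

```python
from typing import Callable, List, Set, Tuple

Seq = Tuple[str, ...]

def reconstruct(token_sets: List[Set[str]],
                passes_check: Callable[[Seq], bool]) -> Set[Seq]:
    """Extend accepted prefixes one position at a time (hypothetical sketch)."""
    finished: Set[Seq] = set()
    partial: Set[Seq] = {(t,) for t in token_sets[0] if passes_check((t,))}
    for candidates in token_sets[1:]:
        extended: Set[Seq] = set()
        for seq in partial:
            children = {seq + (t,) for t in candidates if passes_check(seq + (t,))}
            if children:
                extended |= children
            else:
                finished.add(seq)  # no accepted continuation: treat as complete
        partial = extended
    return finished | partial

# Hypothetical client batch with sequences of different lengths.
client_batch = {("the", "cat", "sat"), ("a", "dog", "ran"), ("hi",)}

def ideal_check(seq: Seq) -> bool:
    """Toy oracle: accepts exactly the prefixes of client sequences."""
    return any(s[:len(seq)] == seq for s in client_batch)

# Hypothetical per-position token sets, as a first-layer check might return.
token_sets = [{"the", "a", "hi"}, {"cat", "dog"}, {"sat", "ran"}]
print(sorted(reconstruct(token_sets, ideal_check)))
```

Cross combinations such as ("the", "dog") are pruned as soon as they fail the check, so the number of partial sequences stays close to the batch size instead of growing with the product of the set sizes.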

Theorems & Definitions (15)

  • Theorem 3.1: Adapted from SPEAR
  • Theorem 5.1
  • Theorem 5.2
  • Theorem B.1
  • Proof
  • Theorem B.1
  • Proof
  • Lemma B.1
  • Proof
  • Lemma B.2
  • ...and 5 more