Stochastic RAG: End-to-End Retrieval-Augmented Generation through Expected Utility Maximization
Hamed Zamani, Michael Bendersky
TL;DR
The paper tackles end-to-end optimization of retrieval-augmented generation (RAG), which is hindered by the non-differentiable retrieval step and by the marginalization and document-independence assumptions common in prior work. It treats retrieval as a stochastic process of sampling documents without replacement and maximizes an expected utility objective, using straight-through Gumbel-top-$k$ as a differentiable approximation of that sampling. The core objective is $\text{RAG Expected Utility} = \frac{1}{n} \sum_{(x,y)\in T} \sum_{\hat{y}\in \mathcal{Y}} U(y,\hat{y})\, p(\hat{y}\mid x; G_\theta, R_\phi)$, where $T$ is a training set of $n$ input-output pairs, $\mathcal{Y}$ is the space of possible outputs, and the utility $U(y,\hat{y})$ scores a candidate output $\hat{y}$ against the reference $y$. The output probability marginalizes over ordered lists of $k$ retrieved documents, $p(\hat{y}\mid x; G_\theta, R_\phi)= \sum_{\mathbf{d}\in \pi_k(C)} p(\hat{y}\mid x,\mathbf{d}; G_\theta)\, p(\mathbf{d}\mid x; R_\phi)$, where $G_\theta$ is the generator, $R_\phi$ is the retriever, and $\pi_k(C)$ is the set of ordered $k$-document lists drawn from the collection $C$; Ancestral Gumbel-Top-$k$ is used to sample from $p(\mathbf{d}\mid x; R_\phi)$ differentiably. Applied to FiD-Light on the KILT benchmark, the method achieves state-of-the-art performance on six of seven datasets and proves robust across model sizes and tasks. The framework accepts arbitrary downstream metrics as the utility $U$ and could potentially improve grounding, faithfulness, and output diversity in RAG systems.
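To make the objective concrete, here is a minimal sketch that evaluates the per-example term exactly on a toy problem by enumerating every ordered $k$-tuple in $\pi_k(C)$. It assumes that $p(\mathbf{d}\mid x; R_\phi)$ is a Plackett-Luce distribution over retriever scores (each position is a softmax over the documents not yet chosen), which is the distribution Gumbel-top-$k$ samples from. The function names, toy generator, utility, and scores are hypothetical placeholders, not the paper's code.

```python
import itertools
import math


def plackett_luce_prob(scores, ordered_docs):
    """p(d | x): probability of drawing `ordered_docs` in this order without
    replacement, each draw being a softmax over the not-yet-drawn documents.
    Assumption: the retriever distribution has this Plackett-Luce form."""
    remaining = set(range(len(scores)))
    prob = 1.0
    for doc in ordered_docs:
        denom = sum(math.exp(scores[j]) for j in remaining)
        prob *= math.exp(scores[doc]) / denom
        remaining.remove(doc)
    return prob


def marginal_output_prob(y_hat, scores, k, generator_prob):
    """p(y_hat | x) = sum over ordered k-tuples d of p(y_hat | x, d) * p(d | x)."""
    return sum(
        generator_prob(y_hat, docs) * plackett_luce_prob(scores, docs)
        for docs in itertools.permutations(range(len(scores)), k)
    )


def expected_utility(y, outputs, scores, k, generator_prob, utility):
    """Inner term for one (x, y) pair: sum over y_hat of U(y, y_hat) * p(y_hat | x)."""
    return sum(
        utility(y, y_hat) * marginal_output_prob(y_hat, scores, k, generator_prob)
        for y_hat in outputs
    )


# Hypothetical toy setup: 4 documents, only document 0 contains the answer.
def toy_generator(y_hat, docs):
    p_paris = 0.9 if 0 in docs else 0.3
    return p_paris if y_hat == "paris" else 1.0 - p_paris


def exact_match(y, y_hat):
    return 1.0 if y == y_hat else 0.0


scores = [2.0, 0.5, 0.0, -1.0]  # hypothetical retriever scores for x
eu = expected_utility("paris", ["paris", "rome"], scores, 2, toy_generator, exact_match)
```

Exact enumeration costs $|C|!/(|C|-k)!$ generator calls, which is why sampling is needed for realistic collections; the full objective then averages this per-example term over the $n$ pairs in $T$.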
Abstract
This paper introduces Stochastic RAG, a novel approach for end-to-end optimization of retrieval-augmented generation (RAG) models that relaxes the simplifying assumptions of marginalization and document independence made in most prior work. Stochastic RAG casts the retrieval process in RAG as stochastic sampling without replacement. Through this formulation, we employ straight-through Gumbel-top-k, which provides a differentiable approximation of sampling without replacement and enables effective end-to-end optimization for RAG. We conduct extensive experiments on seven diverse datasets covering a wide range of tasks, from open-domain question answering and fact verification to slot filling for relation extraction and dialogue systems. By applying this optimization method to a recent and effective RAG model, we advance state-of-the-art results on six out of seven datasets.
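The core mechanism named in the abstract, straight-through Gumbel-top-k, can be sketched in plain Python as follows, assuming the retriever assigns one score per candidate document. Perturbing scores with independent Gumbel(0, 1) noise and keeping the k largest yields a sample without replacement; the continuous relaxation used for the backward pass (k rounds of tempered softmax with a soft log(1 - p) mask, similar to the relaxed top-k of Xie and Ermon, 2019) is an assumption of this sketch and not necessarily the paper's exact choice. Plain Python has no autodiff, so the straight-through combination is only indicated in a comment, and all function names are hypothetical.

```python
import math
import random


def sample_gumbel(rng):
    """Draw one standard Gumbel(0, 1) sample by inverse CDF: -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # random() can return exactly 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_top_k(scores, k, rng):
    """Sample k distinct indices without replacement: add independent Gumbel
    noise to each score and keep the k largest (Gumbel-top-k trick)."""
    perturbed = [s + sample_gumbel(rng) for s in scores]
    top = sorted(range(len(scores)), key=lambda i: perturbed[i], reverse=True)[:k]
    return top, perturbed


def relaxed_top_k(perturbed, k, tau):
    """Assumed soft surrogate for the k-hot selection vector: k rounds of
    tempered softmax, softly masking mass assigned in earlier rounds with a
    log(1 - p) penalty. Entries are nonnegative and sum to k."""
    logits = list(perturbed)
    khot = [0.0] * len(logits)
    for _ in range(k):
        m = max(logits)  # shift by the max for numerical stability
        exps = [math.exp((l - m) / tau) for l in logits]
        z = sum(exps)
        probs = [e / z for e in exps]
        khot = [a + p for a, p in zip(khot, probs)]
        logits = [l + math.log(max(1.0 - p, 1e-20)) for l, p in zip(logits, probs)]
    return khot


def straight_through_top_k(scores, k, tau, rng):
    """Forward pass uses the hard k-hot vector of the sampled documents. In an
    autodiff framework one would return hard + soft - stop_gradient(soft),
    whose value equals `hard` while gradients flow through `soft`."""
    top, perturbed = gumbel_top_k(scores, k, rng)
    hard = [1.0 if i in top else 0.0 for i in range(len(scores))]
    soft = relaxed_top_k(perturbed, k, tau)
    return top, hard, soft


rng = random.Random(0)
top, hard, soft = straight_through_top_k([2.0, 0.5, 0.0, -1.0], k=2, tau=0.5, rng=rng)
```

Lowering `tau` pushes the soft vector toward the hard selection at the cost of higher-variance gradients; the hard forward pass keeps the generator conditioned on a genuine set of k distinct documents.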
