Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation

Hongyu Cao, Yuxuan Wu, Yucheng Cai, Xianyu Zhao, Zhijian Ou

TL;DR

This work tackles the challenge of end-to-end training for retrieval-augmented generation by marginalizing over discrete passages with a joint stochastic approximation (JSA) approach. JSA-RAG introduces a prior retriever, a generator, and a posterior retriever. These are trained with E-steps based on Metropolis independence sampling (MIS) and EM-like M-steps, optimizing the marginal likelihood $p_\theta(y|x)$ with unbiased, low-variance gradients. It also explores index rebuilding and passage concatenation to boost training efficiency and decoding performance. Across five datasets and two tasks (open-domain question answering, ODQA, and knowledge-grounded dialogs), JSA-RAG consistently outperforms vanilla RAG and VRAG, delivering significant gains in end-to-end generation and retrieval recall while maintaining comparable training cost. The results highlight JSA-RAG's potential for principled, scalable end-to-end RAG optimization and its applicability to advanced knowledge-intensive applications.

Abstract

Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory. A RAG model consists of two serially connected components (a retriever and a generator). A major challenge in end-to-end optimization of the RAG model is that it requires marginalization over relevant passages (modeled as discrete latent variables) from a knowledge base. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop joint stochastic approximation (JSA) based end-to-end training of RAG, which is referred to as JSA-RAG. The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful in estimating discrete latent variable models. Extensive experiments are conducted on five datasets for two tasks (open-domain question answering, knowledge-grounded dialogs) and show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis shows the efficacy of JSA-RAG from the perspectives of generation, retrieval, and low-variance gradient estimates.
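
To make the marginalization concrete, the objective can be written out as follows. This is a sketch in assumed notation, not taken from the text above: $h$ denotes a retrieved passage, $\mathcal{K}$ the knowledge base, and a single $\theta$ stands for the parameters of both retriever and generator; the paper's own symbols may differ.

$$
p_\theta(y \mid x) = \sum_{h \in \mathcal{K}} p_\theta(h \mid x)\, p_\theta(y \mid x, h)
$$

$$
\nabla_\theta \log p_\theta(y \mid x) = \mathbb{E}_{p_\theta(h \mid x, y)}\big[\nabla_\theta \log p_\theta(h \mid x) + \nabla_\theta \log p_\theta(y \mid x, h)\big]
$$

The second line is the standard Fisher identity: the gradient of the log marginal likelihood is an expectation under the posterior over passages. The sum over $\mathcal{K}$ is intractable for a large knowledge base, so the gradient has to be estimated from samples that approximate this posterior.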

Paper Structure

This paper contains 18 sections, 7 equations, 2 figures, 9 tables, 1 algorithm.

Figures (2)

  • Figure 1: Overview of JSA-RAG. 1) In addition to the (prior) retriever and generator, JSA-RAG introduces an (auxiliary) posterior retriever. 2) During training, the posterior retriever proposes relevant passages, which are accepted or rejected according to probabilities calculated from the three components. The blue dashed line shows this Metropolis independence sampling (MIS), which is a Monte Carlo approximation of the E-step in EM. 3) The filtered passages are then treated as pseudo labels, as shown by the red dotted line. 4) Given the pseudo labels, we can calculate the gradients for the prior retriever, posterior retriever, and generator, respectively, and proceed with parameter updates, much as in supervised training; this corresponds to the M-step in EM.
  • Figure 2: Comparison of the gradient norms from the posterior retriever.
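
The accept/reject step described in the Figure 1 caption can be illustrated with a toy Metropolis independence sampler over four candidate passages. This is a minimal sketch under stated assumptions, not the authors' implementation: the function names (`mis_chain`, `exact_posterior`), the four-passage setup, and all probabilities are hypothetical. The target is the posterior over passages, proportional to the prior retriever probability times the generator likelihood, and the proposal plays the role of the posterior retriever.

```python
import math
import random


def softmax(logits):
    """Convert a list of logits into probabilities."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]


def mis_chain(prior, gen_lik, proposal, n_steps, rng):
    """Metropolis independence sampling over passage indices.

    Target:   p(h | x, y), proportional to prior[h] * gen_lik[h].
    Proposal: proposal[h], standing in for the posterior retriever.
    """
    def weight(h):
        # Importance weight: unnormalized target over proposal probability.
        return prior[h] * gen_lik[h] / proposal[h]

    indices = list(range(len(prior)))
    current = rng.choices(indices, weights=proposal)[0]
    samples = []
    for _ in range(n_steps):
        candidate = rng.choices(indices, weights=proposal)[0]
        accept_prob = min(1.0, weight(candidate) / weight(current))
        if rng.random() < accept_prob:
            current = candidate  # accept the proposed passage
        # On rejection the previous passage is kept and recorded again.
        samples.append(current)
    return samples


def exact_posterior(prior, gen_lik):
    """Exact posterior by enumeration; feasible only for a toy passage set."""
    joint = [p * g for p, g in zip(prior, gen_lik)]
    z = sum(joint)
    return [j / z for j in joint]


# Hypothetical toy values for four passages.
prior = softmax([1.0, 0.5, 0.0, -0.5])      # prior retriever p(h | x)
gen_lik = [0.05, 0.6, 0.2, 0.01]            # generator likelihood p(y | x, h)
proposal = softmax([0.2, 1.0, 0.5, -0.3])   # posterior retriever q(h | x, y)

samples = mis_chain(prior, gen_lik, proposal, n_steps=20000, rng=random.Random(0))
empirical = [samples.count(h) / len(samples) for h in range(len(prior))]
print("empirical:", [round(p, 3) for p in empirical])
print("exact:    ", [round(p, 3) for p in exact_posterior(prior, gen_lik)])
```

In this toy setting the exact posterior can be enumerated, so the sampled frequencies can be checked against it directly. In training, the accepted passages would instead serve as pseudo labels for the gradient updates, as the caption describes.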