Improving End-to-End Training of Retrieval-Augmented Generation Models via Joint Stochastic Approximation
Hongyu Cao, Yuxuan Wu, Yucheng Cai, Xianyu Zhao, Zhijian Ou
TL;DR
This work tackles the challenge of end-to-end training for retrieval-augmented generation (RAG) by marginalizing over discrete retrieved passages with a joint stochastic approximation (JSA) approach. JSA-RAG introduces a prior retriever, a generator, and a posterior retriever. These are trained with E-steps based on Metropolis independence sampling (MIS) and EM-like M-steps, so that the marginal likelihood $p_\theta(y|x)$ is optimized with unbiased, low-variance gradients. The work also explores index rebuilding and passage concatenation to improve training efficiency and decoding performance. Across five datasets and two tasks (open-domain question answering (ODQA) and knowledge-grounded dialogs), JSA-RAG consistently outperforms vanilla RAG and VRAG. It delivers significant gains in end-to-end generation and retrieval recall at comparable training cost. The results highlight JSA-RAG's potential for principled, scalable end-to-end RAG optimization and its applicability to advanced knowledge-intensive applications.
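A minimal sketch of one JSA iteration on a toy model may help make the E-step/M-step description concrete. It assumes a small set of K candidate passages, and the neural prior retriever, generator, and posterior retriever are replaced by plain logit arrays (`prior_logits`, `gen_logits`, `post_logits` are hypothetical stand-ins, and this is not the authors' implementation). The E-step is one MIS move that uses the posterior retriever $q(z|x,y)$ as its proposal and takes a per-example cached passage index as the chain state. The M-step treats the accepted passage as if it were observed and takes one gradient-ascent step on each model's log-probability.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def log_joint(z, prior_logits, gen_logits, y):
    """log p(z|x) + log p(y|x,z) for a single toy example (x is implicit)."""
    return log_softmax(prior_logits)[z] + log_softmax(gen_logits[z])[y]


def mis_step(z_cached, prior_logits, gen_logits, post_logits, y, rng):
    """One Metropolis independence sampling step targeting p(z|x,y).

    Proposal: the (toy) posterior retriever q(z|x,y).
    Chain state: the passage index cached for this example.
    """
    log_q = log_softmax(post_logits)
    z_prop = rng.choice(len(log_q), p=np.exp(log_q))
    # Log importance weights: log w(z) = log p(z|x) + log p(y|x,z) - log q(z|x,y).
    log_w_prop = log_joint(z_prop, prior_logits, gen_logits, y) - log_q[z_prop]
    log_w_cur = log_joint(z_cached, prior_logits, gen_logits, y) - log_q[z_cached]
    # Accept with probability min(1, w(z_prop) / w(z_cached)).
    if np.log(rng.random()) < log_w_prop - log_w_cur:
        return z_prop
    return z_cached


def grad_log_softmax(logits, idx):
    """Gradient of log softmax(logits)[idx] with respect to logits."""
    g = -np.exp(log_softmax(logits))
    g[idx] += 1.0
    return g


def jsa_update(z, prior_logits, gen_logits, post_logits, y, lr=0.1):
    """M-step-like update (in place): treat the accepted passage z as observed."""
    prior_logits += lr * grad_log_softmax(prior_logits, z)      # prior retriever
    gen_logits[z] += lr * grad_log_softmax(gen_logits[z], y)    # generator
    post_logits += lr * grad_log_softmax(post_logits, z)        # posterior retriever
```

Because the acceptance test depends only on a ratio of importance weights, the intractable normalizer $p_\theta(y|x)$ cancels. In this toy, K is small enough to enumerate the exact posterior, so a long chain of `mis_step` calls can be checked against it.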
Abstract
Retrieval-augmented generation (RAG) has become a widely recognized paradigm for combining parametric memory with non-parametric memory. A RAG model consists of two serially connected components: a retriever and a generator. A major challenge in end-to-end optimization of a RAG model is that it requires marginalizing over relevant passages from a knowledge base, which are modeled as discrete latent variables. Traditional top-K marginalization and variational RAG (VRAG) suffer from biased or high-variance gradient estimates. In this paper, we propose and develop JSA-RAG, an end-to-end training method for RAG based on joint stochastic approximation (JSA). The JSA algorithm is a stochastic extension of the EM (expectation-maximization) algorithm and is particularly powerful for estimating discrete latent variable models. Extensive experiments on five datasets for two tasks (open-domain question answering and knowledge-grounded dialogs) show that JSA-RAG significantly outperforms both vanilla RAG and VRAG. Further analysis demonstrates the efficacy of JSA-RAG in terms of generation, retrieval, and low-variance gradient estimation.
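To make the marginalization concrete, here is a toy sketch. It assumes the knowledge base is small enough to enumerate and that the retriever scores $\log p(z|x)$ and generator likelihoods $\log p(y|x,z)$ are given as arrays. The function names and numbers are illustrative and do not come from the paper. The sketch contrasts the exact marginal likelihood with one common variant of top-K marginalization, which renormalizes the retriever distribution over the K highest-scoring passages.

```python
import numpy as np


def logsumexp(v):
    """Numerically stable log(sum(exp(v)))."""
    m = v.max()
    return m + np.log(np.exp(v - m).sum())


def exact_log_marginal(log_prior, log_lik):
    """log p(y|x) = log sum_z p(z|x) p(y|x,z) over every passage z."""
    return logsumexp(log_prior + log_lik)


def topk_log_marginal(log_prior, log_lik, k):
    """Top-K approximation: keep the k passages with the highest prior score,
    renormalize p(z|x) over them, and marginalize only over that subset."""
    top = np.argsort(log_prior)[::-1][:k]
    renorm = log_prior[top] - logsumexp(log_prior[top])
    return logsumexp(renorm + log_lik[top]), top


# Toy knowledge base: the retriever prefers passages 0 and 1,
# but only passage 3 actually supports the target y.
log_prior = np.log(np.array([0.4, 0.3, 0.2, 0.1]))
log_lik = np.log(np.array([0.01, 0.01, 0.01, 0.9]))

exact = exact_log_marginal(log_prior, log_lik)
approx, top = topk_log_marginal(log_prior, log_lik, k=2)
print(np.exp(exact), np.exp(approx), sorted(top.tolist()))
```

Here the exact marginal is 0.099, while the top-2 estimate is 0.01. Passage 3 never enters the truncated sum, so it receives no gradient signal in either the retriever or the generator. This is one simple way to see why truncated marginalization can give biased training signals.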
