Streamlining the Collaborative Chain of Models into A Single Forward Pass in Generation-Based Tasks
Yuanjie Lyu, Chao Zhang, Yuhao Chen, Yong Chen, Tong Xu
TL;DR
This work targets inefficiencies in chain-of-model generation by letting models share KV hidden states directly, using a prompt-tuning-based method called FTHSS. FTHSS trains downstream models to consume upstream KV states, recomputes those KV states online during training to save storage, and uses reordered inputs and cascade attention masks to maintain compatibility across rounds. Empirical results on four tasks show that FTHSS matches the chain’s performance while substantially reducing inference latency and KV-cache storage, benefiting both single-round and multi-round deployments on a single device. The approach advances practical deployment of retrieval-augmented and agent-based systems, though it may not apply to closed APIs or to very large models without further adaptation; future work may extend scalability and generalizability to larger architectures and datasets.
Abstract
In Retrieval-Augmented Generation (RAG) and agent-based frameworks, the "Chain of Models" approach is widely used: multiple specialized models work sequentially on distinct sub-tasks. This approach is effective but increases resource demands, because each model must be deployed separately. Recent work attempts to address this by applying prompt tuning, which allows a shared base model to adapt to multiple tasks with minimal parameter changes. However, a key challenge remains: intermediate outputs are passed between models as plain text, so each downstream model must recompute the hidden states (i.e., the Key and Value (KV) states in Transformers) during inference. In this paper, we introduce FTHSS, a novel prompt-tuning method that enables models to share KV hidden states, eliminating redundant forward passes and reducing KV cache storage. By modifying the input and the attention masks during training, FTHSS allows models to effectively utilize KV hidden states from prior models in both single- and multi-round scenarios. Empirical results on four tasks show that FTHSS matches the performance of traditional model chains while improving inference efficiency.
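The cost that the abstract describes can be illustrated with a minimal sketch. It is not the FTHSS implementation: it is a toy single-layer, single-head attention in pure Python, and every name in it (`embed`, `extend_cache`, `attend`, the weights `W_Q`, `W_K`, `W_V`, and the token ids) is a hypothetical stand-in. It compares a plain-text handoff, where the downstream model re-encodes the upstream tokens, with a KV handoff, where it reuses the upstream cache. In this one-layer toy the K/V of a token depend only on that token, so the reused cache is trivially valid; in a real multi-layer model prompt-tuned per task, upstream KV states differ from what the downstream model would compute itself, which is the gap FTHSS trains the downstream model to bridge.

```python
import math

D = 4  # toy hidden size (hypothetical)

def matvec(m, v):
    return [sum(w * x for w, x in zip(row, v)) for row in m]

# Fixed toy projection weights standing in for a shared base model.
W_Q = [[0.1 * ((i + j) % 3 - 1) for j in range(D)] for i in range(D)]
W_K = [[0.1 * ((i * j) % 3 - 1) for j in range(D)] for i in range(D)]
W_V = [[0.1 * ((i + 2 * j) % 3 - 1) for j in range(D)] for i in range(D)]

def embed(token_id):
    # Deterministic toy embedding of an integer token id.
    return [math.sin(token_id * (k + 1)) for k in range(D)]

def extend_cache(cache, token_ids, counter):
    """Append one (K, V) pair per new token and count the projections."""
    for t in token_ids:
        x = embed(t)
        cache.append((matvec(W_K, x), matvec(W_V, x)))
        counter["kv"] += 1
    return cache

def attend(token_id, cache):
    """Softmax attention of one query token over all cached (K, V) pairs."""
    q = matvec(W_Q, embed(token_id))
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k, _ in cache]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[i] for w, (_, v) in zip(weights, cache))
            for i in range(D)]

upstream_tokens = [3, 7, 11, 5, 2, 9]  # upstream prompt plus its output
downstream_tokens = [4, 8]             # tokens the downstream model adds

# The upstream model builds its KV cache while it runs.
upstream_counter = {"kv": 0}
upstream_cache = extend_cache([], upstream_tokens, upstream_counter)

# Plain-text handoff: the downstream model re-encodes every token.
text_counter = {"kv": 0}
cache_a = extend_cache([], upstream_tokens + downstream_tokens, text_counter)
out_a = attend(downstream_tokens[-1], cache_a)

# KV handoff: the downstream model reuses the upstream cache.
kv_counter = {"kv": 0}
cache_b = extend_cache(list(upstream_cache), downstream_tokens, kv_counter)
out_b = attend(downstream_tokens[-1], cache_b)

print("KV projections, text handoff:", text_counter["kv"])  # 8
print("KV projections, KV handoff:", kv_counter["kv"])      # 2
print("same attention output:",
      all(math.isclose(a, b) for a, b in zip(out_a, out_b)))
```

The saving grows with the length of the upstream context, since the text handoff repeats work proportional to every upstream token at each stage of the chain.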
