Streamlining the Collaborative Chain of Models into A Single Forward Pass in Generation-Based Tasks

Yuanjie Lyu, Chao Zhang, Yuhao Chen, Yong Chen, Tong Xu

TL;DR

This work targets inefficiencies in chain-of-model generation by enabling direct KV hidden state sharing across models via a prompt-tuning-based method called FTHSS. FTHSS trains downstream models to consume upstream KV states, with online recomputation of KV during training to save storage, and uses reordered inputs and cascade attention masks to maintain compatibility across rounds. Empirical results on four tasks show FTHSS matches the chain’s performance while substantially reducing inference latency and KV-cache storage, benefiting both single-round and multi-round deployments on a single device. The approach advances practical deployment of retrieval-augmented and agent-based systems, though it may not apply to closed APIs and very large models without further adaptation; future work may extend scalability and generalizability to larger architectures and datasets.

Abstract

In Retrieval-Augmented Generation (RAG) and agent-based frameworks, the "Chain of Models" approach is widely used, where multiple specialized models work sequentially on distinct sub-tasks. This approach is effective but increases resource demands as each model must be deployed separately. Recent advancements attempt to address this by applying prompt tuning, which allows a shared base model to adapt to multiple tasks with minimal parameter changes. However, a key challenge remains: intermediate outputs, passed between models as plain text, require recomputation of hidden states (i.e., Key and Value (KV) states in Transformers) during inference. In this paper, we introduce FTHSS, a novel prompt-tuning method that enables models to share KV hidden states, eliminating redundant forward passes and reducing KV cache storage. By modifying input and attention masks during training, FTHSS allows models to effectively utilize KV hidden states from prior models in both single- and multi-round scenarios. Empirical results on four tasks show that FTHSS matches the performance of traditional model chains while improving inference efficiency.
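To make the redundancy described above concrete, the following is a minimal cost-model sketch, not the authors' implementation. It counts how many tokens must be encoded in prefill when a chain hands outputs over as plain text versus when each stage reuses the KV states already computed upstream. The `Stage` class, the function names, and the token counts are illustrative assumptions; the sketch also simplifies by treating each stage's unique input as a short prompt appended after the shared context.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Stage:
    """One model in the chain (hypothetical description used only in this sketch)."""
    own_prompt_tokens: int  # tokens unique to this stage, e.g. its task prompt
    generated_tokens: int   # tokens this stage produces for the next stage


def prefill_tokens_text_chain(shared_input_tokens: int, stages: List[Stage]) -> int:
    """Tokens encoded in prefill when outputs are passed on as plain text."""
    total = 0
    context = shared_input_tokens
    for stage in stages:
        # Each stage re-encodes the whole accumulated context plus its own prompt.
        total += context + stage.own_prompt_tokens
        # The next stage will also see what this stage generated.
        context += stage.generated_tokens
    return total


def prefill_tokens_kv_sharing(shared_input_tokens: int, stages: List[Stage]) -> int:
    """Tokens encoded in prefill when every stage reuses upstream KV states."""
    # The shared input is encoded once; generated tokens already have KV states
    # from decoding, so each stage only encodes its own prompt.
    return shared_input_tokens + sum(s.own_prompt_tokens for s in stages)


# Assumed example: a 1000-token shared input and a two-stage chain A -> B.
chain = [Stage(own_prompt_tokens=20, generated_tokens=100),
         Stage(own_prompt_tokens=20, generated_tokens=50)]
print(prefill_tokens_text_chain(1000, chain))  # 2140
print(prefill_tokens_kv_sharing(1000, chain))  # 1040
```

Under these assumed numbers the plain-text chain encodes roughly twice as many tokens, and the gap widens as more stages are added, since every additional stage re-encodes the full accumulated context.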

Paper Structure

This paper contains 36 sections, 10 equations, 4 figures, 10 tables.

Figures (4)

  • Figure 1: Comparison of "Chain of Models" (a) and FTHSS (b): In (a), models sequentially pass outputs as plain text, requiring KV recomputation. In (b), FTHSS shares KV hidden states, reducing redundant forward passes. Because PEFT methods let multiple models be deployed on a single device, with only a small set of parameters changing between them, sharing hidden states incurs no communication overhead.
  • Figure 2: An example of fine-tuning model B in the model chain A → B. For simplicity, the unique inputs of model A and model B are omitted. Left: Offline fine-tuning, where the output KV hidden states of the fully trained model A are stored and used as input for model B. Middle: Online fine-tuning, where the output KV hidden states of model A are recomputed in memory. Right: The output KV hidden states of model A are computed in memory and model B is fine-tuned by adjusting the attention mask for each layer. The online training strategy is used in practical applications.
  • Figure 3: Cascade attention mask for every layer in the multi-round scenario.
  • Figure 4: Performance comparison of three fine-tuning strategies on the Context Compression & QA task: (1) No additional fine-tuning, using noisy KV hidden states directly; (2) FTHSS (5,000 samples), where the standard-prompt-tuning model is fine-tuned on 5,000 examples; and (3) Fully FTHSS, where the base model undergoes full-dataset fine-tuning.
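
The storage saving behind the online strategy in Figure 2 can be estimated with the standard KV-cache size formula. The sketch below is a back-of-the-envelope calculation, not a figure from the paper: the model configuration (32 layers, 8 KV heads, head dimension 128, 16-bit values) and the 1,000 tokens per example are assumptions chosen only for illustration.

```python
def kv_cache_bytes(num_tokens: int, num_layers: int, num_kv_heads: int,
                   head_dim: int, bytes_per_value: int = 2) -> int:
    """Size of the KV cache for `num_tokens` tokens of a decoder-only Transformer."""
    # Factor 2: one key vector and one value vector per token, layer and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens


# Hypothetical configuration (assumed for this sketch, not taken from the paper).
LAYERS, KV_HEADS, HEAD_DIM = 32, 8, 128

per_token = kv_cache_bytes(1, LAYERS, KV_HEADS, HEAD_DIM)
print(per_token)  # 131072 bytes, i.e. 128 KiB per token

# Storing upstream KV states offline for 5,000 examples of 1,000 tokens each
# (the per-example length is an assumption).
offline_total = kv_cache_bytes(5_000 * 1_000, LAYERS, KV_HEADS, HEAD_DIM)
print(f"{offline_total / 2**30:.1f} GiB")
```

Under these assumptions, storing upstream KV states offline would take on the order of 610 GiB for 5,000 examples, whereas recomputing them in memory per batch trades that storage for one extra forward pass of model A.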