
Learning to Decode Collaboratively with Multiple Language Models

Shannon Zejiang Shen, Hunter Lang, Bailin Wang, Yoon Kim, David Sontag

TL;DR

Co-Llm introduces a latent-variable framework for token-level collaboration between a base LLM and an assistant LLM, formalizing the choice of which model generates each token as $Z_t$ with $P(X,Z)=\prod_{t} P_{\theta}(Z_t|X_{<t}) P_{Z_t}(X_t|X_{<t})$ and $P(X)=\prod_{t} \sum_{Z_t} P_{\theta}(Z_t|X_{<t}) P_{Z_t}(X_t|X_{<t})$. It trains by maximizing the marginal likelihood, using a lightweight linear head on the base model to predict $Z_t$ and a greedy decoding policy controlled by a threshold $\eta$, enabling inference-time control over collaboration frequency. Empirically, Co-Llm improves performance on instruction following, mathematical reasoning, and biomedical QA tasks, including cross-domain and cross-scale model pairings, often matching or exceeding fine-tuning gains while reducing the number of calls to large models. The work highlights interpretable collaboration patterns (e.g., template-filling) and positions Co-Llm as a modular, cost-efficient approach to leveraging domain experts without retraining large LMs, with potential extensions to more LMs and more complex deferral strategies.
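The marginal likelihood above can be illustrated with a small numeric sketch. This is not the authors' implementation: it assumes a single assistant model, takes $Z_t=1$ to mean "defer to the assistant", and uses hypothetical per-token probabilities passed in as plain lists rather than real model outputs.

```python
import math
from typing import Sequence


def marginal_log_likelihood(
    defer_probs: Sequence[float],       # P_theta(Z_t = 1 | X_<t), assumed given
    base_token_probs: Sequence[float],  # P_base(X_t | X_<t) for the observed token
    asst_token_probs: Sequence[float],  # P_asst(X_t | X_<t) for the observed token
) -> float:
    """Return log P(X) = sum_t log sum_{Z_t} P_theta(Z_t | X_<t) P_{Z_t}(X_t | X_<t).

    With one assistant, the inner sum has two terms:
    (1 - p_t) * P_base(x_t) + p_t * P_asst(x_t).
    """
    total = 0.0
    for p_defer, p_base, p_asst in zip(defer_probs, base_token_probs, asst_token_probs):
        # Marginalize the latent choice Z_t out of the per-token likelihood.
        total += math.log((1.0 - p_defer) * p_base + p_defer * p_asst)
    return total


# Hypothetical two-token sequence: the head never defers on the first token
# and always defers on the second.
example = marginal_log_likelihood([0.0, 1.0], [0.5, 0.1], [0.2, 0.25])
```

Maximizing this quantity with respect to the deferral head (and, if desired, the base model) needs no labels for $Z_t$, which is why the TL;DR describes the training as marginal-likelihood maximization without direct supervision.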

Abstract

We propose a method to teach multiple large language models (LLMs) to collaborate by interleaving their generations at the token level. We model the decision of which LLM generates the next token as a latent variable. By optimizing the marginal likelihood of a training set under our latent variable model, the base LLM automatically learns when to generate a token itself and when to call on one of the "assistant" language models to generate, all without direct supervision. Token-level collaboration during decoding allows for a fusion of each model's expertise in a manner tailored to the specific task at hand. Our collaborative decoding is especially useful in cross-domain settings where a generalist base LLM learns to invoke domain expert models. On instruction-following, domain-specific QA, and reasoning tasks, we show that the performance of the joint system exceeds that of the individual models. Through qualitative analysis of the learned latent decisions, we show that models trained with our method exhibit several interesting collaboration patterns, e.g., template-filling. Our code is available at https://github.com/clinicalml/co-llm.

Paper Structure (38 sections, 8 equations, 4 figures, 12 tables, 1 algorithm)


Figures (4)

  • Figure 1: Example generations of our method, Co-Llm. Top: the base model generates the answer template and uses a larger Llama model to fill in factual knowledge; Bottom: the base model uses a math-specialized model as an "API" for computation. The assistant model generated the highlighted tokens because the base model learned to defer generation at those locations.
  • Figure 2: Illustration of the decoding procedure in Co-Llm, where a base model (Llama-7b) and an assistant model (Meditron-70B) collaborate to generate a correct response to a medical question. For each token, the deferral control predicts the probability of switching to the assistant model to decode the next token given the context: it defers when the probability is above some threshold $\eta$ (indicated by A), and uses the decoded token as the context. The deferral control thus invokes the assistant model when suitable, and the generations from both models are interleaved: tokens highlighted with an orange border constitute the final generation. When used alone, the base model may make factual mistakes (indicated by B); Co-Llm learns to use the assistant model at these positions to produce correct generations.
  • Figure 3: Performance of Co-Llm at different frequencies of deferral on GSM8k. There exists an optimal deferral frequency $f$ at which the joint model outperforms either model used alone. A similar trend is observed on MATH and BioASQ, shown in \ref{fig:thresholds-acc-more} in the Appendix.
  • Figure 4: Performance of Co-Llm at different frequencies of deferral on GSM8k, MATH and BioASQ. There exists an optimal deferral frequency $f$ at which the joint model outperforms either model used alone.
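
The threshold-based decoding procedure described in the Figure 2 caption can be sketched as follows. This is a minimal illustration under stated assumptions, not the released Co-Llm code: `base_next`, `asst_next`, and `defer_prob` are hypothetical callables standing in for greedy next-token prediction by each model and for the deferral head, and the toy models at the bottom are invented purely to exercise the loop.

```python
from typing import Callable, List, Sequence, Tuple

Token = str
NextTokenFn = Callable[[Sequence[Token]], Token]   # greedy next token given context
DeferProbFn = Callable[[Sequence[Token]], float]   # P(defer to assistant | context)


def collaborative_greedy_decode(
    prompt: Sequence[Token],
    base_next: NextTokenFn,
    asst_next: NextTokenFn,
    defer_prob: DeferProbFn,
    eta: float,
    max_new_tokens: int = 32,
    eos: Token = "<eos>",
) -> Tuple[List[Token], List[bool]]:
    """Interleave two models token by token using a deferral threshold eta."""
    context = list(prompt)
    deferred: List[bool] = []
    for _ in range(max_new_tokens):
        # Defer when the predicted deferral probability exceeds the threshold.
        use_assistant = defer_prob(context) > eta
        token = asst_next(context) if use_assistant else base_next(context)
        # Whichever model produced the token, it becomes shared context.
        context.append(token)
        deferred.append(use_assistant)
        if token == eos:
            break
    return context[len(prompt):], deferred


# Toy stand-ins (hypothetical): the base model writes the answer template but
# gets the number wrong; the assistant supplies the number.
_BASE_SCRIPT = ["The", "answer", "is", "41", "<eos>"]
prompt = ["Q"]
base_next = lambda ctx: _BASE_SCRIPT[len(ctx) - len(prompt)]
asst_next = lambda ctx: "42"
defer_prob = lambda ctx: 0.9 if ctx[-1] == "is" else 0.1

tokens, deferred = collaborative_greedy_decode(prompt, base_next, asst_next, defer_prob, eta=0.5)
deferral_frequency = sum(deferred) / len(deferred)
```

In this toy run, raising `eta` to 1.0 disables deferral entirely, so the output reduces to the base model alone; lowering it increases the deferral frequency $f$ plotted in Figures 3 and 4.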