A Theoretical Framework for Modular Learning of Robust Generative Models

Corinna Cortes, Mehryar Mohri, Yutao Zhong

TL;DR

A theoretical framework for modular generative modeling in which a set of pre-trained experts is combined via a gating mechanism. The authors prove that this modular approach can theoretically outperform models retrained on aggregate data, with the gap characterized by the Jensen-Shannon Divergence.

Abstract

Training large-scale generative models is resource-intensive and relies heavily on heuristic dataset weighting. We address two fundamental questions: Can we train Large Language Models (LLMs) modularly (combining small, domain-specific experts to match monolithic performance), and can we do so robustly for any data mixture, eliminating heuristic tuning? We present a theoretical framework for modular generative modeling in which a set of pre-trained experts is combined via a gating mechanism. We define the space of normalized gating functions, $\mathcal{G}_{1}$, and formulate the problem as a minimax game to find a single robust gate that minimizes divergence to the worst-case data mixture. We prove the existence of such a robust gate using Kakutani's fixed-point theorem and show that modularity acts as a strong regularizer, with generalization bounds scaling with the lightweight gate's complexity. Furthermore, we prove that this modular approach can theoretically outperform models retrained on aggregate data, with the gap characterized by the Jensen-Shannon Divergence. Finally, we introduce a scalable Stochastic Primal-Dual algorithm and a Structural Distillation method for efficient inference. Empirical results on synthetic and real-world datasets confirm that our modular architecture effectively mitigates gradient conflict and can robustly outperform monolithic baselines.
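The abstract states that the modular advantage is characterized by the Jensen-Shannon Divergence. As background (a standard information-theoretic identity, not the paper's theorem), the entropy of a pooled mixture $m = \sum_k \lambda_k p_k$ exceeds the $\lambda$-weighted average of the component entropies by exactly the weighted JSD, $\sum_k \lambda_k \mathrm{KL}(p_k \,\|\, m)$. The sketch below checks this numerically on made-up distributions; the helper names (`entropy`, `kl`, `jsd`) are illustrative, not from the paper. For two equally weighted distributions with disjoint supports the gap equals $\log 2$.

```python
import numpy as np

def entropy(p):
    """Shannon entropy in nats; entries with p(x) = 0 contribute 0."""
    p = np.asarray(p, dtype=float)
    nz = p > 0
    return float(-np.sum(p[nz] * np.log(p[nz])))

def kl(p, q):
    """KL(p || q) in nats, assuming q(x) > 0 wherever p(x) > 0."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    nz = p > 0
    return float(np.sum(p[nz] * np.log(p[nz] / q[nz])))

def jsd(dists, weights):
    """Weighted Jensen-Shannon divergence: sum_k w_k KL(p_k || m), m = sum_k w_k p_k."""
    dists = np.asarray(dists, dtype=float)
    weights = np.asarray(weights, dtype=float)
    m = weights @ dists                      # pooled mixture distribution
    return float(sum(w * kl(p, m) for w, p in zip(weights, dists)))

# Two hypothetical "domains" with disjoint supports, mixed 50/50.
p1 = np.array([0.9, 0.1, 0.0, 0.0])
p2 = np.array([0.0, 0.0, 0.2, 0.8])
lam = np.array([0.5, 0.5])
mix = lam @ np.vstack([p1, p2])

# Entropy increase caused by pooling the domains into one distribution.
gap = entropy(mix) - (lam[0] * entropy(p1) + lam[1] * entropy(p2))
print(gap, jsd([p1, p2], lam))
```

The identity holds for overlapping supports too; the disjoint case just makes the value easy to verify by hand, since each $\mathrm{KL}(p_k \,\|\, m)$ reduces to $\log 2$.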

Paper Structure (56 sections, 17 theorems, 94 equations, 10 figures, 2 tables, 6 algorithms)

Key Result

Lemma 0

The family $\mathcal{G}_{1}$ is non-empty, compact, and convex.
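The non-emptiness and convexity parts of the lemma can be illustrated on a toy finite domain. This sketch assumes a hypothetical reading of $\mathcal{G}_{1}$ that is not taken from the paper: nonnegative gate values $g_k(x)$ whose gated mixture $p_g(x) = \sum_k g_k(x)\, p_k(x)$ sums to one. Under that reading membership is defined by linear constraints, so any convex combination of two members is again a member. The paper's exact definition (including whatever bounds make the family compact) may differ, and `gated_mixture` and `in_G1` are hypothetical helpers.

```python
import numpy as np

def gated_mixture(g, P):
    """Hypothetical gated distribution p_g(x) = sum_k g[x, k] * P[x, k].
    g and P have shape (n_inputs, n_experts); column k of P is expert k's distribution."""
    return np.sum(g * P, axis=1)

def in_G1(g, P, tol=1e-9):
    """Assumed membership test: nonnegative gate whose gated mixture sums to one."""
    return bool(np.all(g >= -tol) and abs(gated_mixture(g, P).sum() - 1.0) < tol)

rng = np.random.default_rng(0)
n, k = 6, 3
P = rng.random((n, k))
P /= P.sum(axis=0, keepdims=True)  # each column is a distribution over the n inputs

# Non-empty: the constant uniform gate gives sum_x sum_k P[x, k] / k = 1.
uniform = np.full((n, k), 1.0 / k)

# A non-constant member: route each x to its most likely expert, then rescale.
route = np.zeros((n, k))
route[np.arange(n), P.argmax(axis=1)] = 1.0
route /= gated_mixture(route, P).sum()

# Convex: mixing two members keeps nonnegativity and the sum-to-one constraint.
for alpha in (0.0, 0.3, 0.7, 1.0):
    mix = alpha * uniform + (1 - alpha) * route
    print(alpha, in_G1(mix, P))
```

Compactness cannot be checked this way; on a finite domain it would additionally require the gate values to be bounded, which this assumed definition does not enforce on its own.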

Figures (10)

  • Figure 1: Conceptual Architecture of the Modular Gated Solution.
  • Figure 2: Geometry of Prior Knowledge (see the section on prior knowledge). The outer triangle represents the full probability simplex $\Delta$. The green region $\Lambda$ represents the subset of valid mixtures defined by prior knowledge. The algorithm projects the adversary's updates (red point) back onto $\Lambda$ (blue point), tightening the worst-case bound as stated in the Quantitative Improvement in Game Value theorem.
  • Figure 3: Visualizing the JSD Gap. A gated model (blue) fits distinct modes perfectly by routing inputs. A retrained model (red) suffers from capacity interference, forcing an entropy increase proportional to the JSD.
  • Figure 4: Efficiency Strategies. (A) Monolithic Distillation trains a single large model to mimic the ensemble, discarding the original experts. (B) Structural Distillation trains a lightweight Causal Router to mimic only the gating decisions ($g^*$), preserving the original experts. This maintains modularity: upgrading an expert in (B) improves the system immediately.
  • Figure 5: Modularity overcomes gradient conflict. The left panel compares against the Fixed Smaller and Larger models (green); the right panel compares against the Oracle models (red). Lines show mean values over 5 runs, and shaded regions show standard deviations. The Robust Gate (blue) maintains consistently low loss across all mixture weights. The Fixed models are similarly consistent, but both at significantly higher loss values. The Oracle models in the right panel naturally obtain a lower loss in the skewed regions ($\lambda < 0.3$ and $\lambda > 0.7$), but both the Smaller (dashed) and Larger (solid) Oracles suffer from interference in the high-entropy region ($\lambda \approx 0.5$), forming a concave error curve. Remarkably, the modular system outperforms the monolithic Larger Oracle in this mixed regime despite having a significantly smaller total parameter count.
  • ...and 5 more figures
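Figure 2 describes the adversary's mixture weights being projected back onto $\Lambda$ after each update. The sketch below shows one such projected ascent step under assumptions that are not from the paper: $\Lambda$ is taken to be a box-constrained simplex $\{\lambda \in \Delta : \mathrm{lo} \le \lambda \le \mathrm{hi}\}$, the update is a plain Euclidean gradient-ascent step on made-up per-domain losses, and `project_box_simplex` and `adversary_step` are hypothetical names. The projection uses the standard KKT form $\lambda_i = \mathrm{clip}(v_i - \tau, \mathrm{lo}_i, \mathrm{hi}_i)$ with the scalar $\tau$ found by bisection.

```python
import numpy as np

def project_box_simplex(v, lo, hi, iters=100):
    """Euclidean projection of v onto {lam : sum(lam) = 1, lo <= lam <= hi}.
    The solution is clip(v - tau, lo, hi) for a scalar tau, found by bisection.
    Assumes sum(lo) <= 1 <= sum(hi) so the set is non-empty."""
    v, lo, hi = (np.asarray(a, dtype=float) for a in (v, lo, hi))
    if lo.sum() > 1 + 1e-12 or hi.sum() < 1 - 1e-12:
        raise ValueError("constraint set is empty")
    # Total mass is non-increasing in tau: >= 1 at t_low, <= 1 at t_high.
    t_low, t_high = np.min(v - hi), np.max(v - lo)
    for _ in range(iters):
        tau = 0.5 * (t_low + t_high)
        if np.clip(v - tau, lo, hi).sum() > 1:
            t_low = tau   # too much mass: raise the threshold
        else:
            t_high = tau
    return np.clip(v - 0.5 * (t_low + t_high), lo, hi)

def adversary_step(lam, losses, eta, lo, hi):
    """Hypothetical adversary update: move weight toward the domains where the
    current gate does worst, then project back onto Lambda."""
    return project_box_simplex(lam + eta * np.asarray(losses, dtype=float), lo, hi)

# Hypothetical prior knowledge: every domain weight lies in [0.1, 0.6].
lo = np.full(3, 0.1)
hi = np.full(3, 0.6)
lam = np.full(3, 1.0 / 3.0)
losses = np.array([2.0, 0.5, 0.4])  # made-up per-domain losses of the current gate
for _ in range(50):
    lam = adversary_step(lam, losses, eta=0.1, lo=lo, hi=hi)
print(lam.round(3))
```

With these fixed, made-up losses the weights settle at $(0.6, 0.3, 0.1)$: the adversary pushes the worst-served domain to its cap and gives the next-worst domain as much of the remaining mass as the lower bound on the third domain allows. A box is only one possible form of prior knowledge; any closed convex $\Lambda$ with a computable projection fits the same step.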

Theorems & Definitions (34)

  • Lemma 0
  • proof
  • Proposition 0: Fixed Mixture Guarantee
  • proof
  • Theorem 1: Robust Existence
  • proof
  • Theorem 2: Dominance of the Specialized Solution
  • proof
  • Theorem 3: Quantitative Improvement in Game Value
  • proof
  • ...and 24 more