CAMEx: Curvature-aware Merging of Experts

Dung V. Nguyen, Minh H. Nguyen, Luc Q. Nguyen, Rachel S. Y. Teo, Tan M. Nguyen, Linh Duy Tran

TL;DR

CAMEx introduces curvature-aware merging for Sparse Mixture of Experts by leveraging natural gradients to account for the non-Euclidean geometry of neural parameter spaces. It replaces costly Fisher-based curvature estimates with learned, efficiently approximated curvature matrices via Kronecker-factor decompositions and a test-time reparameterization, enabling scalable merging during pre-training and fine-tuning. The method includes a dynamic architecture that maintains the same active-expert capacity while reducing parameters and FLOPs, and provides both theoretical alignment with the empirical Fisher and extensive empirical gains across language modeling, QA, classification, and vision tasks. The results demonstrate faster convergence and improved generalization without prohibitive memory overhead, with public code released to foster reproducibility and adoption.

Abstract

Existing methods for merging experts during model training and fine-tuning predominantly rely on Euclidean geometry, which assumes a flat parameter space. This assumption can limit the model's generalization ability, especially during the pre-training phase, where the parameter manifold might exhibit more complex curvature. Curvature-aware merging methods typically require additional information and computational resources to approximate the Fisher Information Matrix, adding memory overhead. In this paper, we introduce CAMEx (Curvature-Aware Merging of Experts), a novel expert merging protocol that incorporates natural gradients to account for the non-Euclidean curvature of the parameter manifold. By leveraging natural gradients, CAMEx adapts more effectively to the structure of the parameter space, improving alignment between model updates and the manifold's geometry. This approach enhances both pre-training and fine-tuning, resulting in better optimization trajectories and improved generalization without the substantial memory overhead typically associated with curvature-aware methods. Our contributions are threefold: (1) CAMEx significantly outperforms traditional Euclidean-based expert merging techniques across various natural language processing tasks, leading to enhanced performance during pre-training and fine-tuning; (2) we introduce a dynamic merging architecture that optimizes resource utilization, achieving high performance while reducing computational costs, facilitating efficient scaling of large language models; and (3) we provide both theoretical and empirical evidence to demonstrate the efficiency of our proposed method. The code is publicly available at: https://github.com/kpup1710/CAMEx.
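
For readers unfamiliar with natural gradients, the textbook contrast with an ordinary gradient step may help. The symbols below ($\theta$, $\eta$, $\mathcal{L}$, $\mathbf{F}$) are generic notation introduced here for illustration and are not taken from the abstract; this is the standard definition, not a statement of the exact CAMEx update.

$$
\begin{aligned}
\text{Euclidean step:}\quad & \theta_{t+1} = \theta_t - \eta\, \nabla_\theta \mathcal{L}(\theta_t), \\
\text{natural-gradient step:}\quad & \theta_{t+1} = \theta_t - \eta\, \mathbf{F}(\theta_t)^{-1} \nabla_\theta \mathcal{L}(\theta_t), \\
\text{with}\quad & \mathbf{F}(\theta) = \mathbb{E}\!\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^{\top}\right].
\end{aligned}
$$

When $\mathbf{F}$ is the identity, the two steps coincide, which is the flat-space assumption the abstract attributes to Euclidean merging. Storing or inverting a full $\mathbf{F}$ scales quadratically or worse with the number of parameters, which is the memory overhead the abstract refers to.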

Paper Structure

This paper contains 40 sections, 30 equations, 9 figures, 19 tables, 1 algorithm.

Figures (9)

  • Figure 1: Overview of CAMEx for a causal language modeling SMoE. The experts are merged using the router scores and the curvature matrix $\mathbf{M}$. During merging, the masks for the domain vectors, denoted $\gamma_i$, can be generated with methods such as TIES or DARE. We follow the causal segmenting pipeline of Lory (Zhong et al., 2024) to satisfy both memory-efficiency and causal-information constraints. Note that a stop-gradient operator is applied to the router scores of the first segment.
  • Figure 2: Overall architecture of different SMoE layers. The figure presents the vanilla SMoE layer on the left, the merging expert layer in the middle, and our proposed dynamic merging SMoE layer on the right. Our architecture reduces the number of parameters compared to the other two, while maintaining the same number of activated neurons per layer. Importantly, despite the dynamic merging mechanism, our architecture preserves the same number of experts at each layer as the other SMoE architectures, ensuring comparable model capacity, i.e., the number of activated parameters per layer.
  • Figure 3: Perplexity of GPT2-small variants starting at the tenth epoch.
  • Figure 4: Impact of the scaling factor $\alpha$ on the performance of the curvature-aware method across NLP tasks. Scaling factors in the range $[0.8, 1]$ consistently improve the model's performance.
  • Figure 5: Impact of the Kronecker rank of the curvature matrix on the model's performance. As the rank increases, performance drops and then saturates. Note, however, that this curve may change depending on the downstream task and the merging protocol.
  • ...and 4 more figures
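
The captions for Figures 1, 4, and 5 mention router scores, a curvature matrix $\mathbf{M}$, optional masks, a scaling factor $\alpha$, and a Kronecker rank. The following NumPy sketch shows one way these pieces could fit together. It is an illustration under stated assumptions, not the authors' implementation. The update form (a base expert plus $\alpha$ times curvature-transformed, score-weighted domain vectors), the reading of a domain vector as "expert minus base", and every name (`merge_experts`, `left`, `right`, `masks`) are hypothetical. The curvature is assumed to be a sum of Kronecker products, applied to a weight matrix through the identity $\mathrm{vec}(LXR) = (R^{\top} \otimes L)\,\mathrm{vec}(X)$ so that the full matrix is never materialized.

```python
import numpy as np


def merge_experts(experts, base, scores, left, right, masks=None, alpha=1.0):
    """Hypothetical curvature-aware merge of expert weight matrices.

    experts : list of (d_out, d_in) arrays, one per expert
    base    : (d_out, d_in) array, the expert that the others are merged into
    scores  : sequence of router scores, one per expert
    left    : left[i] is a list of (d_out, d_out) Kronecker factors for expert i
    right   : right[i] is a list of (d_in, d_in) Kronecker factors for expert i
              (len(left[i]) == len(right[i]) is the assumed Kronecker rank)
    masks   : optional list of (d_out, d_in) masks (e.g. from TIES or DARE)
    alpha   : scaling factor applied to the merged update
    """
    merged = np.asarray(base, dtype=float).copy()
    for i, expert in enumerate(experts):
        # Assumed definition of a domain vector: expert minus base.
        tau = np.asarray(expert, dtype=float) - base
        if masks is not None:
            tau = masks[i] * tau
        weighted = scores[i] * tau
        # Apply sum_r (R_r^T kron L_r) to vec(weighted) without building it:
        # each term equals L_r @ weighted @ R_r.
        curved = sum(L @ weighted @ R for L, R in zip(left[i], right[i]))
        merged += alpha * curved
    return merged


# Toy usage with made-up shapes: two experts, rank-1 identity factors.
rng = np.random.default_rng(0)
base = rng.normal(size=(3, 2))
experts = [rng.normal(size=(3, 2)) for _ in range(2)]
scores = np.array([0.7, 0.3])
left = [[np.eye(3)] for _ in experts]
right = [[np.eye(2)] for _ in experts]
merged = merge_experts(experts, base, scores, left, right, alpha=0.9)
```

With identity factors the sketch reduces to plain score-weighted (Euclidean) merging, so any effect of curvature comes entirely from the learned factors. Raising the Kronecker rank adds more factor pairs per expert, which is the quantity varied in Figure 5.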