A Mean Field Ansatz for Zero-Shot Weight Transfer

Xingyuan Chen, Wenwei Kuang, Lei Deng, Wei Han, Bo Bai, Goncalo dos Reis

TL;DR

The paper proposes a row-column (RC) mean-field ansatz to theoretically justify zero-shot weight transfer in neural networks, modeling weights as a structured joint distribution whose empirical measure evolves under training. By decomposing weight matrices into RC components, the authors show the weight-transfer process can be viewed as sampling from a limit RC-measure, enabling width-expansion transfers between models of different sizes. Empirical validation on MLPs (CIFAR-10) and large language models (GPT-3, Llama-3.1) demonstrates RC-consistent correlation patterns and successful weight transfer, supporting the mean-field perspective as a mechanism behind model growth and pruning. The work also provides extensive appendix material detailing the RC construction, initialization, and extensions to other architectures, while acknowledging questions about existence/uniqueness of limit measures and convergence rates in practical settings.

Abstract

The pre-training cost of large language models (LLMs) is prohibitive. One cutting-edge approach to reduce the cost is zero-shot weight transfer, also known in some cases as model growth, which magically transfers the weights trained in a small model to a large model. However, there are still some theoretical mysteries behind weight transfer. In this paper, inspired by prior applications of mean field theory to neural network dynamics, we introduce a mean field ansatz to provide a theoretical explanation for weight transfer. Specifically, we propose the row-column (RC) ansatz under the mean field point of view, which describes the measure structure of the weights in the neural network (NN) and admits a closed measure dynamic. Thus, the weights of NNs of different sizes admit a common distribution under proper assumptions, and weight transfer methods can be viewed as sampling methods. We empirically validate the RC ansatz by exploring simple MLP examples and LLMs such as GPT-3 and Llama-3.1. We show that the mean-field point of view is adequate under suitable assumptions, which can provide theoretical support for zero-shot weight transfer.
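
The reading of weight transfer as sampling can be illustrated with a toy width expansion. The sketch below is not the paper's algorithm. It assumes a small MLP whose hidden layers average over incoming units with a 1/N factor (one common mean-field scaling), and it builds a wider network by resampling the hidden-unit indices of the small one. The names `mf_forward` and `expand_width` and all sizes are hypothetical, and the random rate $r_1$ and norm rate $r_2$ mentioned in the figure captions are not modeled.

```python
import numpy as np

def mf_forward(x, W1, W2, w3):
    """Forward pass of a toy MLP with mean-field (1/N) averaging (assumed scaling)."""
    h1 = np.tanh(W1 @ x)                  # (N,) first hidden layer, input dim is fixed
    h2 = np.tanh(W2 @ h1 / W2.shape[1])   # (N,) average over the N incoming units
    return w3 @ h2 / w3.shape[0]          # scalar, average over the N hidden units

def expand_width(W1, W2, w3, idx):
    """Build a wider network by indexing the small network's hidden units with idx."""
    # Rows of W1, rows and columns of W2, and entries of w3 share the same unit index.
    return W1[idx, :], W2[np.ix_(idx, idx)], w3[idx]

rng = np.random.default_rng(0)
d, N, N_big = 5, 8, 32                    # hypothetical input dim and widths
W1 = rng.normal(size=(N, d))
W2 = rng.normal(size=(N, N))
w3 = rng.normal(size=N)
x = rng.normal(size=d)

# Deterministic resampling: each small-model unit is copied N_big // N times.
idx_tiled = np.arange(N_big) % N
# Random resampling: hidden-unit indices drawn i.i.d. uniformly.
idx_random = rng.integers(0, N, size=N_big)

y_small = mf_forward(x, W1, W2, w3)
y_tiled = mf_forward(x, *expand_width(W1, W2, w3, idx_tiled))
y_random = mf_forward(x, *expand_width(W1, W2, w3, idx_random))

print(y_small, y_tiled, y_random)
```

With equal-count copies the wide network reproduces the small network's output up to floating-point error, because every 1/N average is unchanged. With i.i.d. index sampling the output only matches approximately, and the gap depends on how evenly the indices happen to be drawn.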

Paper Structure

This paper contains 34 sections, 47 equations, 6 figures, 1 table, and 1 algorithm.

Figures (6)

  • Figure 4.1: Heat plots of the normalized values of the middle layer in a 3-layer NN with biases under 3 different setups with $N=300$. The upper plots show the results at initialization and the lower plots show the results after training for 10 epochs. At initialization, SP and MF use a uniform distribution, and $\mu$P uses a Gaussian distribution.
  • Figure 4.2: Correlation matrices for a 7-layer MFNN example under different bias settings with $N=1000$. (a) Non-constant trainable bias; (b) Constant trainable bias; (c) no bias.
  • Figure 4.3: Training loss and test accuracy of the 3-layer NN example under different settings with $\eta = 0.1$. (a) and (d) use different initial distributions. (b) and (e) show the weight transfer results from the $N=100$ model to the $N=1000$ model at the 4th epoch with different random rates $r_1$ and norm rates $r_2$. (c) and (f) show the weight transfer from the $N=1000$ model to the $N=100$ model with different random rates $r_1$ and norm rates $r_2$.
  • Figure 4.4: Maximum absolute correlation coefficients for different LLMs, where $.C$ denotes the mean value on rows and $.R$ denotes the mean value on columns. For each LLM, the maximum absolute value of the correlation over different blocks is shown.
  • Figure 4.5: Distribution of the row means of the MLP layer weight matrix in $\mathbb{R}^{N\times 4N}$ for the small model ($N=256$) and the large model ($N=1024$) at different training steps $T\in\{4k,12k,20k,32k\}$ and in different blocks, with the weights initialized i.i.d. from $\mathcal{N}(1,3)$.
  • ...and 1 more figure

Theorems & Definitions (10)

  • Example 1.1: Simple 3-layer NN
  • Example 1.2: 2-layer MFNN from 2019-mf-lln
  • Definition 3.1: $\gamma$ notation
  • Definition 3.2: $\Gamma$ set
  • Example 3.3: $\gamma$ notation and $\Gamma$ set
  • Example 3.5: Weight transfer for Example 3.3
  • Example 6.1: Weight transfer for Example 3.3
  • Example 6.2
  • Example 6.3
  • Example 6.4