Early Neuron Alignment in Two-layer ReLU Networks with Small Initialization
Hancheng Min, Enrique Mallada, René Vidal
TL;DR
The paper studies how gradient flow trains a two-layer ReLU network for binary classification when the initialization is small and the data are well separated. It introduces a finite-$\epsilon$ analysis, with $\epsilon$ the initialization scale, that identifies an early alignment phase in which first-layer neurons align with either the positive or the negative data, entering the cones $\mathcal{S}_+$ or $\mathcal{S}_-$; the time $t_1$ needed for this alignment is rigorously bounded by $t_1=\mathcal{O}\left(\frac{\log n}{\sqrt{\mu}}\right)$. After alignment, training effectively decouples into two linear subnetworks, leading to $\mathcal{O}\left(\frac{1}{t}\right)$ loss decay and a nearly low-rank first-layer weight matrix with stable rank at most $2$. MNIST experiments illustrate the predicted alignment and convergence dynamics, and the analysis clarifies how the data separation $\mu$ and the initialization scale shape the training behavior and implicit bias of the model.
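The low-rank claim above is phrased in terms of stable rank. The NumPy sketch below assumes the standard definition $\|W\|_F^2/\|W\|_2^2$, which never exceeds $\operatorname{rank}(W)$. It contrasts a first-layer matrix whose rows are all scaled copies of two directions (a hypothetical stand-in for the post-alignment regime) with an unstructured Gaussian matrix. The sizes `d`, `h` and the directions `u_plus`, `u_minus` are illustrative assumptions, not quantities taken from the paper.

```python
import numpy as np


def stable_rank(W: np.ndarray) -> float:
    """Return ||W||_F^2 / ||W||_2^2, a continuous surrogate that is <= rank(W)."""
    fro_sq = np.linalg.norm(W, "fro") ** 2
    spec_sq = np.linalg.norm(W, 2) ** 2  # squared largest singular value
    if spec_sq == 0.0:
        raise ValueError("stable rank is undefined for the zero matrix")
    return float(fro_sq / spec_sq)


rng = np.random.default_rng(0)
d, h = 20, 50  # hypothetical input dimension and number of hidden neurons

# Hypothetical "aligned" first layer: every neuron (row) is a positive multiple
# of one of two directions, standing in for the positive and negative data.
u_plus = rng.standard_normal(d)
u_minus = rng.standard_normal(d)
scales = rng.uniform(0.5, 2.0, size=h)
W_aligned = np.vstack(
    [s * (u_plus if j % 2 == 0 else u_minus) for j, s in enumerate(scales)]
)

# Unstructured comparison: i.i.d. Gaussian rows with no alignment.
W_random = rng.standard_normal((h, d))

print(f"aligned: {stable_rank(W_aligned):.3f}")
print(f"random:  {stable_rank(W_random):.3f}")
```

Because `W_aligned` has rank 2, its stable rank lies in $[1, 2]$, while the Gaussian matrix spreads its energy over many singular values and lands well above 2.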
Abstract
This paper studies the problem of training a two-layer ReLU network for binary classification using gradient flow with small initialization. We consider a training dataset with well-separated input vectors: any pair of input data with the same label is positively correlated, and any pair with different labels is negatively correlated. Our analysis shows that, during the early phase of training, each neuron in the first layer tries to align with either the positive data or the negative data, depending on its corresponding weight in the second layer. A careful analysis of the neurons' directional dynamics allows us to provide an $\mathcal{O}(\frac{\log n}{\sqrt{\mu}})$ upper bound on the time it takes for all neurons to achieve good alignment with the input data, where $n$ is the number of data points and $\mu$ measures how well the data are separated. After the early alignment phase, the loss converges to zero at an $\mathcal{O}(\frac{1}{t})$ rate, and the first-layer weight matrix is approximately low-rank. Numerical experiments on the MNIST dataset illustrate our theoretical findings.
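The well-separation condition can be checked numerically on a dataset. The sketch below assumes that $\mu$ is the smallest label-signed cosine similarity $\min_{i,j} y_i y_j \frac{\langle x_i, x_j\rangle}{\|x_i\|\|x_j\|}$. This is an assumption: the abstract only says that $\mu$ measures how well the data are separated. The function name `separation_margin` and the toy points are hypothetical, chosen for illustration.

```python
import numpy as np


def separation_margin(X: np.ndarray, y: np.ndarray) -> float:
    """Smallest label-signed cosine similarity y_i y_j cos(x_i, x_j) over all pairs.

    X holds one data point per row; y holds labels in {+1, -1}.
    A positive value means every same-label pair is positively correlated
    and every different-label pair is negatively correlated.
    (Assumed definition of mu; hypothetical helper, not from the paper.)
    """
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalize each row
    signed = Xn * y[:, None]  # row i becomes y_i * x_i / ||x_i||
    gram = signed @ signed.T  # entry (i, j) = y_i y_j cos(x_i, x_j)
    return float(gram.min())  # diagonal entries equal 1, so the result is <= 1


# Toy 2-D data (made up): positives near +e1, negatives near -e1.
X = np.array([[1.0, 0.2], [0.9, -0.1], [-1.0, 0.1], [-0.8, -0.3]])
y = np.array([1, 1, -1, -1])
mu = separation_margin(X, y)
print(f"mu = {mu:.4f}, well separated: {mu > 0}")

# Flipping one label breaks the condition: points 0 and 1 are positively
# correlated but now carry different labels.
y_bad = np.array([1, -1, -1, -1])
print(f"mu with a flipped label = {separation_margin(X, y_bad):.4f}")
```

Under this assumed definition, a larger $\mu$ means the two classes occupy narrower, more opposed cones, which is the regime in which the $\mathcal{O}(\frac{\log n}{\sqrt{\mu}})$ alignment-time bound becomes smaller.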
