Guiding Neural Collapse: Optimising Towards the Nearest Simplex Equiangular Tight Frame

Evan Markou, Thalaiyasingam Ajanthan, Stephen Gould

TL;DR

This work introduces the notion of the nearest simplex ETF geometry for the penultimate layer features at any given training iteration, formulated as a Riemannian optimisation problem, and shows that optimising towards this geometry accelerates convergence and enhances training stability.

Abstract

Neural Collapse (NC) is a recently observed phenomenon in neural networks that characterises the solution space of the final classifier layer when trained until zero training loss. Specifically, NC suggests that the final classifier layer converges to a Simplex Equiangular Tight Frame (ETF), which maximally separates the weights corresponding to each class. By duality, the penultimate layer feature means also converge to the same simplex ETF. Since this simple symmetric structure is optimal, our idea is to utilise this property to improve convergence speed. Specifically, we introduce the notion of nearest simplex ETF geometry for the penultimate layer features at any given training iteration, by formulating it as a Riemannian optimisation. Then, at each iteration, the classifier weights are implicitly set to the nearest simplex ETF by solving this inner-optimisation, which is encapsulated within a declarative node to allow backpropagation. Our experiments on synthetic and real-world architectures for classification tasks demonstrate that our approach accelerates convergence and enhances training stability.
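
To make the geometry concrete, the following is a minimal NumPy sketch, not the authors' implementation. It builds the canonical simplex ETF for $C$ classes and then finds a nearby rotated ETF for a matrix of class feature means. The inner Riemannian optimisation described in the abstract is replaced here by a closed-form orthogonal Procrustes solve via SVD, which is an assumption made purely for illustration; the function names `simplex_etf` and `nearest_etf_procrustes` and the shape convention (features as a $d \times C$ matrix with $d \ge C$) are likewise hypothetical.

```python
import numpy as np


def simplex_etf(num_classes: int) -> np.ndarray:
    """Canonical simplex ETF: M = sqrt(C / (C - 1)) * (I - 11^T / C).

    Columns have unit norm and pairwise inner product -1 / (C - 1).
    """
    C = num_classes
    return np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)


def nearest_etf_procrustes(H: np.ndarray, M: np.ndarray):
    """Illustrative stand-in for the inner optimisation (an assumption).

    Minimises ||H - U M||_F over U with orthonormal columns (U^T U = I)
    using the orthogonal Procrustes solution. H has shape (d, C), d >= C.
    Returns the rotation U and the rotated ETF W = U M.
    """
    # Thin SVD of the cross-correlation between the features and the ETF.
    P, _, Qt = np.linalg.svd(H @ M.T, full_matrices=False)
    U = P @ Qt  # shape (d, C), orthonormal columns
    return U, U @ M


# Usage: random feature means for C = 4 classes in d = 16 dimensions.
rng = np.random.default_rng(0)
C, d = 4, 16
M = simplex_etf(C)
H = rng.standard_normal((d, C))
U, W = nearest_etf_procrustes(H, M)

# W is itself a simplex ETF: rotating M leaves its Gram matrix unchanged.
gram = W.T @ W
```

Because $M$ has rank $C-1$, the minimising $U$ is not unique in this simplified form, although the product $U M$ recovered above is still a valid simplex ETF. A practical solver therefore needs some additional mechanism to select a stable solution across training iterations.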

Paper Structure

This paper contains 25 sections, 3 theorems, 51 equations, 18 figures, 3 tables.

Key Result

Proposition 1

Consider the optimisation problem in Equation eqn:closestetfprox. Assume that the solution exists and that the objective function $f$ and the constraint function $J$ are twice differentiable in the neighbourhood of the solution. If the $\mathrm{rank}({\bm{A}}) = \frac{C(C+1)}{2}$ and ${\bm{G}}$ is n where, Here, the double dot product symbol $(:)$ denotes a tensor contraction on appropriate indic

Figures (18)

  • Figure 1: Schematic of our proposed architecture for optimising towards the nearest simplex ETF. The classifier weights ${\bm{W}} = {\bm{U}}^\star {\bm{M}}$ are an implicit function of the CNN features ${\bm{H}}$. Note that the parameters of the CNN are updated via two gradient paths from the loss function ${\cal L}$, a direct path (top) and an indirect path through ${\bm{U}}^\star$ (bottom).
  • Figure 2: UFM-10 results. In all plots, the x-axis represents the number of epochs, except for plot (c), where the x-axis denotes the number of training examples.
  • Figure 3: The evolution of convergence measured in top-1 accuracy of the UFM as we increase the number of classes, plotted for the first 800 epochs. We omit the rest of the epochs as all methods have converged and have identical results.
  • Figure 4: ImageNet results on ResNet-50. In all plots, the x-axis represents the number of epochs, except for plot (c), where the x-axis denotes the number of training examples.
  • Figure 5: CIFAR100 computational cost results on ResNet-50. In (a), we plot the forward pass time for each method. For the implicit ETF method, which has dynamic computation times, we also include the mean and median time values. In (b), we plot the computational cost for each forward and backward pass across methods. For the implicit ETF forward pass, we have taken its median time. The notation is as follows: S/F = Standard Forward Pass, S/B = Standard Backward Pass, F/F = Fixed ETF Forward Pass, F/B = Fixed ETF Backward Pass, I/F = Implicit ETF Forward Pass, and I/B = Implicit ETF Backward Pass.
  • ...and 13 more figures

Theorems & Definitions (4)

  • Proposition 1: Following directly from Proposition 4.5 in ddn-original
  • Corollary 1: magnus-matrix-calculus
  • Proposition 2: Lagrange multiplier functions for Stiefel Manifolds
  • proof