Efficient Model Compression Techniques with FishLeg

Jamie McGowan, Wei Sheng Lai, Weibin Chen, Henry Aldridge, Jools Clarke, Jezabel Garcia, Rui Xia, Yilei Liang, Guillaume Hennequin, Alberto Bernacchia

TL;DR

The paper tackles the challenge of pruning large neural networks efficiently by exploiting second-order information without prohibitive memory costs. It introduces FishLeg Surgeon (FLS), which uses the FishLeg optimizer to meta-learn a parametric approximation of the inverse Fisher, $Q(\boldsymbol{\lambda}) \approx F_\gamma^{-1}$, and to update curvature estimates online during pruning, avoiding full re-computation. Key contributions include a memory-efficient block-diagonal parameterization for $Q$, initialization and preconditioning strategies that speed convergence, and empirical validation showing improved accuracy at high sparsity on ResNet18/CIFAR-10 and TinyIM. This approach integrates second-order optimization with model compression, enabling scalable, accurate pruning suitable for resource-constrained deployments.

Abstract

In many domains, the most successful AI models tend to be the largest, indeed often too large to be handled by AI players with limited computational resources. To mitigate this, a number of compression methods have been developed, including methods that prune the network down to high sparsity whilst retaining performance. The best-performing pruning techniques are often those that use second-order curvature information (such as an estimate of the Fisher information matrix) to score the importance of each weight and to predict the optimal compensation for weight deletion. However, these methods are difficult to scale to high-dimensional parameter spaces without making heavy approximations. Here, we propose the FishLeg surgeon (FLS), a new second-order pruning method based on the Fisher-Legendre (FishLeg) optimizer. At the heart of FishLeg is a meta-learning approach to amortising the action of the inverse FIM, which brings a number of advantages. Firstly, the parameterisation enables the use of flexible tensor factorisation techniques to improve computational and memory efficiency without sacrificing much accuracy, alleviating challenges associated with scalability of most second-order pruning methods. Secondly, directly estimating the inverse FIM leads to less sensitivity to the amplification of stochasticity during inversion, thereby resulting in more precise estimates. Thirdly, our approach also allows for progressive assimilation of the curvature into the parameterisation. In the gradual pruning regime, this results in a more efficient estimate refinement as opposed to re-estimation. We find that FishLeg achieves higher or comparable performance against two common baselines in the area, most notably in the high sparsity regime when considering a ResNet18 model on CIFAR-10 (84% accuracy at 95% sparsity vs 60% for OBS) and TinyIM (53% accuracy at 80% sparsity vs 48% for OBS).
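
To make the phrase "score the importance of each weight and predict the optimal compensation for weight deletion" concrete, the following is a minimal sketch of a single classical Optimal Brain Surgeon (OBS) step, with the inverse Fisher standing in for the inverse curvature. It is not the FLS implementation: the function name `obs_prune_one`, the toy weights and the 3×3 matrix `inv_f` are hypothetical values chosen for illustration, and in FLS the matrix would be replaced by the learned $Q(\boldsymbol{\lambda})$.

```python
def obs_prune_one(weights, inv_fisher):
    """Remove the single least-salient weight and compensate the others.

    weights:    list of floats (the weight vector w)
    inv_fisher: rows of a symmetric positive-definite matrix standing in
                for the inverse Fisher (hypothetical toy values here)
    """
    n = len(weights)
    # OBS saliency: estimated loss increase if weight i is forced to zero,
    # rho_i = w_i^2 / (2 * [F^{-1}]_ii).
    saliency = [weights[i] ** 2 / (2.0 * inv_fisher[i][i]) for i in range(n)]
    i = min(range(n), key=saliency.__getitem__)
    # OBS compensation: delta_w = -(w_i / [F^{-1}]_ii) * F^{-1} e_i,
    # i.e. a multiple of column i of the inverse Fisher.
    scale = weights[i] / inv_fisher[i][i]
    new_weights = [weights[j] - scale * inv_fisher[j][i] for j in range(n)]
    new_weights[i] = 0.0  # set exactly to zero to drop floating-point residue
    return i, saliency, new_weights


# Toy example (all numbers are illustrative assumptions).
w = [0.8, -0.1, 0.5]
inv_f = [[2.0, 0.5, 0.0],
         [0.5, 1.0, 0.2],
         [0.0, 0.2, 4.0]]
pruned_index, saliency, w_after = obs_prune_one(w, inv_f)
print(pruned_index, w_after)
```

In this toy case the second weight has the lowest saliency, and the remaining weights shift slightly to absorb its removal. Magnitude pruning would pick the same weight here, but the two criteria can disagree whenever the diagonal of the inverse Fisher varies strongly across weights.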

Paper Structure

This paper contains 16 sections, 23 equations, 5 figures, 1 table, 1 algorithm.

Figures (5)

  • Figure 1: The initialization of $Q(\boldsymbol{\lambda})$ matters. In this toy experiment, the true Fisher matrix ($n=100$) was chosen so that its $i^\text{th}$ eigenvalue is $\xi_i \triangleq 1/i^2$, and the damping parameter $\gamma$ was set to $10^{-3}$. Thus, the eigenvalues of $F_\gamma^{-1}$ lie roughly in the range $[1, 1000]$. The auxiliary loss $\mathcal{A}(Q) = \frac{1}{2} \text{Tr}(Q F_\gamma Q) - \text{Tr}(Q)$ (left) was minimized by gradient descent w.r.t. the Cholesky factor of $Q(\boldsymbol{\lambda})$, initialized such that $Q(\boldsymbol{\lambda})=I$ (black) or $Q(\boldsymbol{\lambda})=\gamma^{-1} I = 1000 \times I$ (red). The learning rate was optimized separately for each case. This simulation shows that it is clearly better to initialize $Q$ to be large rather than small. Indeed, a simple derivation shows that each eigenvalue $\beta_i$ of $Q$ approaches its target $1/(\xi_i + \gamma)$ at a speed proportional to $(\xi_i + \gamma)$ (see the auxiliary-loss dynamics equation). In other words, the eigenvalues of $Q$ that must end up large are also those that evolve the slowest. It therefore makes sense to initialize them to be large so that they have less distance to travel; the eigenvalues that must end up small will become small rapidly anyway. The right panels illustrate this behaviour by plotting the eigenvalues of $Q$ against their respective targets, at regular intervals during optimization (color-coded), for both initialization schemes. The auxiliary loss is minimized when $\beta_i = 1/(\xi_i+\gamma)$, i.e. when the dots lie along the identity line (dashed grey).
  • Figure 2: Test accuracy as a function of model sparsity for ResNet18 on CIFAR-10 (left) and TinyIM (right). Three pruning frameworks are compared: magnitude pruning (blue), OBS (orange) and FLS (green).
  • Figure 3: 2:4 semi-structured pruning performance of a ResNet18 model fine-tuned on CIFAR-10 and TinyIM.
  • Figure 4: Assessing FishLeg's inverse curvature estimation in a controlled setting. The true Fisher matrix $F \in \mathbb{R}^{100\times 100}$ is constructed to have a random orthonormal eigenbasis and eigenvalues $\lambda_i \propto e^{-i/30}$. All results are averaged over 20 independent realizations of the experiment with different random seeds. (A): standard affine-invariant Riemannian distance between FishLeg's $Q$ and $F_\gamma^{-1}$ ($\gamma=0.01$), as a function of the number of data mini-batches of size $m$ consumed so far. Each Adam step of auxiliary loss minimization consumes one mini-batch. Here we use a full parameterization $Q = LL^\top$ that contains the solution $F^{-1}_\gamma$, so FishLeg's inverse curvature estimation is consistent and the error goes to zero. As a baseline, we show the behaviour of a simple but biased estimator that estimates $F_\gamma$ on each new mini-batch, inverts that noisy estimate, and averages the result over mini-batches; inverting noisy estimates yields a bias that persists asymptotically. (B-D): the inverse Fisher is estimated in structured form (B: diagonal; C: block-diagonal, 5 blocks; D: Kronecker product, $(5\times 5) \otimes (20 \times 20)$). This is done either by FishLeg assuming a correspondingly structured form for $Q$ (red), or by (i) approximating $F_\gamma$ in structured form for each mini-batch (for the Kronecker approximation, we use a permuted SVD to find the nearest Kronecker product in the least-squares sense; Van Loan and Pitsianis, 1993), (ii) averaging over mini-batches (for the Kronecker approximation the two factors are averaged separately, as in KFAC), and (iii) inverting the result (black; note that in this case, the inverse inherits the structure). We report the squared error between $Q\boldsymbol{u}$ and $F_\gamma^{-1} \boldsymbol{u}$, averaged over $\boldsymbol{u} \sim \mathcal{N}(0, \Sigma_u)$, and normalized by the average norm of $F_\gamma^{-1} \boldsymbol{u}$. To reflect the need to estimate accurately the action of $F^{-1}_\gamma$ on the least salient parameter dimensions, we have chosen $\Sigma_u = F^{-1}$.
  • Figure 5: Ablation experiments on synthetic data in a toy setup to show: (A) the utility of preconditioning the auxiliary loss, (B) the predicted quality of the approximated Fisher in different scenarios, (C) the one-shot pruning performance of various Fisher approximations (including other baselines) and (D) the effect of implementing a block-diagonal FishLeg approximation and its comparison to oBERT (an OBS-derived approach) at various block sizes.
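
The initialization argument in the Figure 1 caption can be checked numerically with a few lines of code. The sketch below is a simplification, not the paper's experiment: it runs gradient descent directly on the eigenvalues $\beta_i$ of $Q$ (rather than on the Cholesky factor), assumes $Q$ shares the eigenbasis of the Fisher matrix so that the loss decouples into $\sum_i \left[\tfrac{1}{2}(\xi_i+\gamma)\beta_i^2 - \beta_i\right]$, and uses the same learning rate for both initializations. The function name, learning rate and step count are illustrative assumptions.

```python
def run_gradient_descent(beta_init, n=100, gamma=1e-3, lr=1.0, steps=1000):
    """Gradient descent on the eigenvalues beta_i of Q for the decoupled
    auxiliary loss sum_i [0.5 * (xi_i + gamma) * beta_i**2 - beta_i].

    Returns the worst relative distance of any beta_i from its target.
    """
    xi = [1.0 / i**2 for i in range(1, n + 1)]   # Fisher eigenvalues, as in the caption
    targets = [1.0 / (x + gamma) for x in xi]    # eigenvalues of the damped inverse Fisher
    beta = [beta_init] * n
    for _ in range(steps):
        # The gradient w.r.t. beta_i is (xi_i + gamma) * beta_i - 1, so the
        # error shrinks by a factor (1 - lr * (xi_i + gamma)) at every step.
        beta = [b - lr * ((x + gamma) * b - 1.0) for b, x in zip(beta, xi)]
    return max(abs(b - t) / t for b, t in zip(beta, targets))


gamma = 1e-3
err_small_init = run_gradient_descent(beta_init=1.0, gamma=gamma)          # Q = I
err_large_init = run_gradient_descent(beta_init=1.0 / gamma, gamma=gamma)  # Q = I / gamma
print(f"worst relative error, Q = I:       {err_small_init:.3f}")
print(f"worst relative error, Q = I/gamma: {err_large_init:.3f}")
```

Under these assumptions the large initialization ends closer to the target after the same number of steps, because the slowest directions (small $\xi_i + \gamma$) are exactly those whose targets are near $\gamma^{-1}$.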