An Augmented Backward-Corrected Projector Splitting Integrator for Dynamical Low-Rank Training

Jonas Kusch, Steffen Schotthöfer, Alexandra Walter

TL;DR

This work tackles memory-efficient neural network training via dynamical low-rank training (DLRT) by reformulating weight evolution on the low-rank manifold $Y = USV^{\top}$ and addressing robustness challenges of the Projector Splitting Integrator (PSI). It introduces Augmented Backward-Corrected PSI (abc-PSI), which augments the basis during the backward-corrected PSI to guarantee loss descent and enables rank adaptation through a truncation step, while reducing QR decompositions to a single one per iteration. The authors provide rigorous analyses including robust error bounds and convergence to a local optimum under standard smoothness and boundedness assumptions, and validate the method on MNIST and Vision Transformer fine-tuning tasks, achieving strong compression with competitive accuracy. Overall, abc-PSI offers a theoretically sound, computationally efficient, and practically effective approach for dynamical low-rank training and fine-tuning of large neural networks.
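
The truncation step mentioned above can be illustrated with a minimal NumPy sketch. This is not the authors' abc-PSI implementation. It assumes a tolerance-based rule often used in rank-adaptive low-rank integrators: take the SVD of the small core matrix $S$ and keep the fewest singular values such that the discarded tail has Frobenius norm at most a tolerance. The function name `truncate_factors` and the parameter `tol` are hypothetical.

```python
import numpy as np

def truncate_factors(U, S, V, tol):
    """Truncate Y = U @ S @ V.T to the smallest rank whose discarded
    singular values of S have Frobenius norm at most `tol`.

    Assumes U and V have orthonormal columns, so the singular values
    of S are also the singular values of Y.
    """
    # SVD of the small r x r core only; cost is independent of layer width.
    P, sigma, Qt = np.linalg.svd(S)

    # tail[k] = sqrt(sum of sigma_i^2 for i >= k) = error when keeping k values.
    tail = np.sqrt(np.cumsum(sigma[::-1] ** 2))[::-1]
    tail = np.append(tail, 0.0)  # keeping all values gives zero error

    # tail is non-increasing, so the first index meeting the tolerance
    # is the smallest admissible rank; keep at least rank 1.
    rank = max(int(np.argmax(tail <= tol)), 1)

    U_new = U @ P[:, :rank]
    S_new = np.diag(sigma[:rank])
    V_new = V @ Qt[:rank, :].T
    return U_new, S_new, V_new

# Usage with hypothetical sizes: an 8 x 6 matrix stored at rank 4.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((8, 4)))
V, _ = np.linalg.qr(rng.standard_normal((6, 4)))
S = np.diag([3.0, 2.0, 1e-3, 1e-4])
U1, S1, V1 = truncate_factors(U, S, V, tol=1e-2)
print(S1.shape)
```

Working on the core matrix rather than on the full product keeps the decomposition at size $r \times r$, which is why the cost of such a step depends on the rank and not on the layer dimensions.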

Abstract

Layer factorization has emerged as a widely used technique for training memory-efficient neural networks. However, layer factorization methods face several challenges, particularly a lack of robustness during the training process. To overcome this limitation, dynamical low-rank training methods have been developed, utilizing robust time integration techniques for low-rank matrix differential equations. Although these approaches facilitate efficient training, they still depend on computationally intensive QR and singular value decompositions of matrices with small rank. In this work, we introduce a novel low-rank training method that reduces the number of required QR decompositions. Our approach integrates an augmentation step into a projector-splitting scheme, ensuring convergence to a locally optimal solution. We provide a rigorous theoretical analysis of the proposed method and demonstrate its effectiveness across multiple benchmarks.
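
To make the memory argument concrete, the following sketch counts the parameters of a dense $m \times n$ layer against a factorized layer $USV^{\top}$ of rank $r$. The helper names and the layer size are hypothetical, and the compression rate is assumed here to be one minus the ratio of parameter counts; the paper's own definition may differ.

```python
def full_params(m, n):
    """Parameters of a dense m x n weight matrix."""
    return m * n

def low_rank_params(m, n, r):
    """Parameters of the factors U (m x r), S (r x r), and V (n x r)."""
    return r * (m + n) + r * r

def compression_rate(m, n, r):
    """Fraction of parameters saved by the factorization (assumed definition)."""
    return 1.0 - low_rank_params(m, n, r) / full_params(m, n)

# Hypothetical 500 x 500 layer stored at rank 20.
print(full_params(500, 500))           # 250000
print(low_rank_params(500, 500, 20))   # 20400
print(compression_rate(500, 500, 20))
```

The factorized form only saves memory when $r(m + n + r) < mn$, so the benefit grows as the rank shrinks relative to the layer dimensions.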

Paper Structure

This paper contains 23 sections, 7 theorems, 71 equations, 1 figure, 4 tables, 1 algorithm.

Key Result

Lemma 5.1

(Loss evaluation of the PSI) Let $Y(t)$ be the solution of the PSI evolution equations. Then the loss is bounded by ... with ...

Figures (1)

  • Figure 1: Mean test accuracy of all experimental setups (PSI, bc-PSI, abc-PSI) trained on the MNIST dataset, plotted against their compression rates using learning rates of 0.01 and 0.001. Compression rates correspond to different rank selections (for fixed-rank settings) or varying tolerances (for rank-adaptive settings). Training with a learning rate of 0.01 was unstable for all backward-corrected PSI trainings and original PSI models with ranks $r > 28$, frequently leading to failed trainings; these cases are excluded from the graphic.

Theorems & Definitions (16)

  • Lemma 5.1
  • Proof 1
  • Remark 5.2
  • Theorem 5.3
  • Proof 2
  • Lemma 5.4
  • Proof 3
  • Remark 5.5
  • Theorem 5.6
  • Proof 4
  • ...and 6 more