Exploring Low Rank Training of Deep Neural Networks

Siddhartha Rao Kamalakara, Acyr Locatelli, Bharat Venkitesh, Jimmy Ba, Yarin Gal, Aidan N. Gomez

TL;DR

This work investigates training deep networks in a low-rank, factorised form to obtain memory and speed benefits. Through extensive ablations on GPT-2/LM1B and on vision models, it analyses spectral initialisation, L2 regularisation on factorised weights, and pre-training, challenging prior beliefs about why these techniques work. Key findings: the directions of the singular vectors matter more than the singular values; regularisation affects the effective rank in nontrivial ways, with Frobenius decay helping to maintain a higher effective rank; and pre-training benefits language tasks but does not consistently help vision tasks. The study highlights the need for a deeper theoretical understanding of training dynamics in factorised networks and suggests promising avenues for future work.
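The techniques named above can be made concrete with a short NumPy sketch. It assumes one common formulation of spectral initialisation (draw a dense weight W, take its truncated SVD, and split the kept singular values evenly between the factors as sqrt(S)), contrasts plain L2 on the factors with Frobenius decay on their product, and measures effective rank with the entropy-based definition of Roy and Vetterli (2007). The init scheme, the rank-64 example and that effective-rank definition are assumptions for illustration, not details confirmed by this summary; all function names are hypothetical.

```python
import numpy as np


def spectral_init(in_dim, out_dim, rank, rng):
    """Return a dense weight W and factors U, V with U @ V the best rank-`rank` approximation of W.

    Assumption: W uses a Glorot-style uniform init, and the truncated singular
    values are split evenly between the two factors as sqrt(S).
    """
    limit = np.sqrt(6.0 / (in_dim + out_dim))
    w = rng.uniform(-limit, limit, size=(in_dim, out_dim))
    u_svd, s, vt = np.linalg.svd(w, full_matrices=False)
    root_s = np.sqrt(s[:rank])
    U = u_svd[:, :rank] * root_s       # (in_dim, rank): left singular vectors scaled by sqrt(S)
    V = root_s[:, None] * vt[:rank]    # (rank, out_dim): right singular vectors scaled by sqrt(S)
    return w, U, V


def l2_penalty(U, V, lam):
    # Ordinary weight decay applied to each factor independently.
    return lam * (np.sum(U ** 2) + np.sum(V ** 2))


def frobenius_decay(U, V, lam):
    # Frobenius decay: penalise the squared Frobenius norm of the product U @ V.
    return lam * np.sum((U @ V) ** 2)


def effective_rank(m):
    # Entropy-based effective rank (Roy & Vetterli, 2007); an assumed definition.
    s = np.linalg.svd(m, compute_uv=False)
    p = s / s.sum()
    p = p[p > 0]
    return float(np.exp(-np.sum(p * np.log(p))))


rng = np.random.default_rng(0)
w, U, V = spectral_init(256, 512, rank=64, rng=rng)
print("factor shapes:", U.shape, V.shape)
print("effective rank of W:", effective_rank(w))
print("effective rank of U @ V:", effective_rank(U @ V))
```

With this balanced split, ||U||_F^2 and ||V||_F^2 each equal the sum of the kept singular values, so the plain L2 penalty is twice the nuclear norm of U @ V, whereas Frobenius decay is the sum of the squared singular values. The two penalties therefore shrink the spectrum differently, which offers one possible lens on why the choice of regulariser can change the effective rank.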

Abstract

Training deep neural networks in low rank, i.e. with factorised layers, is of particular interest to the community: it offers efficiency over unfactorised training in terms of both memory consumption and training time. Prior work has focused on low rank approximations of pre-trained networks and on training in low rank space with additional objectives, offering various ad hoc explanations for the chosen practices. We analyse techniques that work well in practice, and through extensive ablations on models such as GPT-2 we provide evidence falsifying common beliefs in the field, hinting in the process at exciting open research questions that remain to be answered.
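The memory and compute savings come from replacing a dense d_in × d_out weight with two factors of inner rank r, which changes the parameter count (and the multiply-adds per input row) from d_in·d_out to r·(d_in + d_out). The arithmetic below is a minimal sketch: the helper names and layer shapes are hypothetical, and reading a "rank of 50%" as 50% of min(d_in, d_out) is an assumption rather than a convention confirmed here.

```python
def dense_params(d_in, d_out):
    # A full weight matrix W of shape (d_in, d_out); biases are ignored.
    return d_in * d_out


def factorised_params(d_in, d_out, rank):
    # W is replaced by U (d_in, rank) and V (rank, d_out); biases are ignored.
    return rank * (d_in + d_out)


def rank_from_ratio(d_in, d_out, ratio):
    # Assumption: a rank ratio is a fraction of the maximal rank min(d_in, d_out).
    return max(1, int(ratio * min(d_in, d_out)))


def break_even_rank(d_in, d_out):
    # Factorisation saves parameters only when the rank is below this value.
    return d_in * d_out / (d_in + d_out)


# Hypothetical layer shapes, chosen only for illustration.
for d_in, d_out in [(1024, 1024), (1024, 4096)]:
    r = rank_from_ratio(d_in, d_out, 0.5)
    full = dense_params(d_in, d_out)
    low = factorised_params(d_in, d_out, r)
    print(f"{d_in}x{d_out}: rank={r}, dense={full}, factorised={low}, "
          f"ratio={low / full:.3f}, break-even rank={break_even_rank(d_in, d_out):.1f}")
```

Under this assumed convention, a square layer at a 50% rank ratio has exactly as many parameters as the dense layer, so square matrices only save parameters below a ratio of one half, while a rectangular 1024×4096 projection at the same ratio needs 0.625 of the dense parameter count.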

Paper Structure

This paper contains 20 sections, 7 equations, 5 figures and 9 tables.

Figures (5)

  • Figure 1: TPU compute hours vs. performance of GPT-2 on LM1B as the model is scaled up. Each point on the line corresponds to a different model size, from a hidden dimension of 1024 (top left) to 2560 (bottom right) in increments of 256.
  • Figure 2: Comparison of interpolation of low rank and pre-trained networks for ResNet-50 on ImageNet with a rank of 50%.
  • Figure 3: Total parameters vs. performance of GPT-2 on LM1B as the model is scaled up. Each point on the line corresponds to a different model size, from a hidden dimension of 1024 (top left) to 2560 (bottom right) in increments of 256.
  • Figure 4: Comparison of interpolation of low rank and pre-trained networks for WideResNet-28 on CIFAR-100 with a rank of 30%.
  • Figure 5: Comparison of interpolation of low rank and pre-trained networks for the Transformer LM.
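Figures 2, 4 and 5 compare interpolation curves. A common protocol, assumed here because the captions do not state the endpoints, evaluates the loss along θ(α) = (1 − α)θ_a + αθ_b for α from 0 to 1. The sketch applies it to a hypothetical toy least-squares model with a rank-r factorised weight; the data, shapes, endpoints and function names are illustrative only.

```python
import numpy as np


def interpolate_params(theta_a, theta_b, alpha):
    # Elementwise linear interpolation between two parameter sets with matching keys and shapes.
    return {k: (1.0 - alpha) * theta_a[k] + alpha * theta_b[k] for k in theta_a}


def interpolation_curve(theta_a, theta_b, loss_fn, num_points=11):
    # Evaluate the loss along the straight line from theta_a (alpha=0) to theta_b (alpha=1).
    alphas = np.linspace(0.0, 1.0, num_points)
    losses = np.array([loss_fn(interpolate_params(theta_a, theta_b, a)) for a in alphas])
    return alphas, losses


# Hypothetical toy problem: regression targets produced by a rank-r weight U_star @ V_star.
rng = np.random.default_rng(0)
d_in, d_out, r, n = 32, 16, 4, 256
X = rng.normal(size=(n, d_in))
U_star = rng.normal(size=(d_in, r))
V_star = rng.normal(size=(r, d_out))
Y = X @ U_star @ V_star


def mse(theta):
    # Mean squared error of the factorised model X @ U @ V.
    pred = X @ theta["U"] @ theta["V"]
    return float(np.mean((pred - Y) ** 2))


theta_init = {"U": 0.1 * rng.normal(size=(d_in, r)), "V": 0.1 * rng.normal(size=(r, d_out))}
theta_target = {"U": U_star, "V": V_star}
alphas, losses = interpolation_curve(theta_init, theta_target, mse)
for a, l in zip(alphas, losses):
    print(f"alpha={a:.1f}  loss={l:.4f}")
```

Interpolating the factors U and V separately does not trace a straight line in the space of the product W = UV: along this path the product is quadratic in α, so interpolation curves for factorised and dense parameterisations are not directly comparable point for point.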