Gradient dynamics for low-rank fine-tuning beyond kernels

Arif Kerem Dayi, Sitan Chen

TL;DR

This work proves, under mild assumptions, that a student model initialized at the base model $f$ and trained with online gradient descent converges to the teacher in $dk^{O(1)}$ iterations, where $k$ is the number of neurons in $f$.

Abstract

LoRA has emerged as one of the de facto methods for fine-tuning foundation models with low computational cost and memory footprint. The idea is to train only a low-rank perturbation to the weights of a pre-trained model, given supervised data for a downstream task. Despite its empirical success, it remains poorly understood, from a mathematical perspective, which learning mechanisms ensure that gradient descent converges to useful low-rank perturbations. In this work we study low-rank fine-tuning in a student-teacher setting. We are given the weights of a two-layer base model $f$, as well as i.i.d. samples $(x,f^*(x))$ where $x$ is Gaussian and $f^*$ is the teacher model given by perturbing the weights of $f$ by a rank-1 matrix. This generalizes the setting of generalized linear model (GLM) regression, which corresponds to the case where the weights of $f$ are zero. When the rank-1 perturbation is comparable in norm to the weight matrix of $f$, the training dynamics are nonlinear. Nevertheless, in this regime we prove under mild assumptions that a student model which is initialized at the base model and trained with online gradient descent will converge to the teacher in $dk^{O(1)}$ iterations, where $k$ is the number of neurons in $f$. Importantly, unlike in the GLM setting, the complexity does not depend on fine-grained properties of the activation's Hermite expansion. We also prove that in our setting, learning the teacher model "from scratch" can require significantly more iterations.
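To make the student-teacher setup concrete, here is a minimal NumPy sketch of online SGD for rank-1 fine-tuning. It assumes a parametrization that the abstract does not spell out: unit second-layer weights, teacher first-layer weights $W + \xi c u^\top$ with $u$ a unit vector, student weights $W + \xi \hat{c} \hat{u}^\top$ with $\hat{c}$ initialized at zero (so the student starts exactly at the base model $f$), squared loss, a $\tanh$ activation standing in for a generic activation, and renormalization of $\hat{u}$ onto the unit sphere after each step. The function names (`forward`, `loss_and_grads`, `sgd_step`, `run_online_sgd`) are hypothetical; this is an illustration under those assumptions, not the authors' code.

```python
import numpy as np


def forward(x, W, xi, c, u):
    """Two-layer net with unit output weights and first layer W + xi * c u^T.

    Assumed parametrization; tanh stands in for a generic "nice" activation.
    """
    return np.tanh((W + xi * np.outer(c, u)) @ x).sum()


def loss_and_grads(x, y, W, xi, c_hat, u_hat):
    """Squared loss 0.5 * (f_hat(x) - y)^2 and its gradients in (c_hat, u_hat)."""
    pre = (W + xi * np.outer(c_hat, u_hat)) @ x   # pre-activations, shape (k,)
    resid = np.tanh(pre).sum() - y
    g = resid * (1.0 - np.tanh(pre) ** 2)         # dLoss/dpre, shape (k,)
    grad_c = xi * g * (u_hat @ x)                 # dpre_j/dc_j = xi * <u_hat, x>
    grad_u = xi * (g @ c_hat) * x                 # dpre_j/du   = xi * c_j * x
    return 0.5 * resid ** 2, grad_c, grad_u


def sgd_step(x, y, W, xi, c_hat, u_hat, lr):
    """One online SGD step on a fresh sample, then project u_hat back to the sphere."""
    _, grad_c, grad_u = loss_and_grads(x, y, W, xi, c_hat, u_hat)
    c_new = c_hat - lr * grad_c
    u_new = u_hat - lr * grad_u
    return c_new, u_new / np.linalg.norm(u_new)


def run_online_sgd(W, xi, c, u, steps, lr, seed=0):
    """Online SGD from the base model (c_hat = 0); returns the overlaps <u_t, u>."""
    rng = np.random.default_rng(seed)
    k, d = W.shape
    c_hat = np.zeros(k)                           # student output equals f at t = 0
    u_hat = rng.standard_normal(d)
    u_hat /= np.linalg.norm(u_hat)
    overlaps = []
    for _ in range(steps):
        x = rng.standard_normal(d)                # fresh Gaussian sample each step
        y = forward(x, W, xi, c, u)               # teacher label f*(x)
        c_hat, u_hat = sgd_step(x, y, W, xi, c_hat, u_hat, lr)
        overlaps.append(u_hat @ u)
    return np.array(overlaps)
```

For instance, with $W$ having orthonormal rows, a unit vector $u$ and a random $c$, `run_online_sgd(W, 1.0, c, u, steps=5000, lr=0.01)` records a trajectory of $\langle u_t, u\rangle$ of the kind plotted in the figures. Under this parametrization $\hat{c}\hat{u}^\top$ is unchanged by $(\hat{c}, \hat{u}) \mapsto (-\hat{c}, -\hat{u})$, so the overlap may approach $-1$ rather than $1$. Note also that $\hat{u}$ receives no gradient until $\hat{c}$ moves away from zero.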

Paper Structure

This paper contains 53 sections, 28 theorems, 184 equations, 5 figures.

Key Result

Theorem 1

Let $0 < \varepsilon < 1$, and let $\xi \asymp \sqrt{k}$ with a sufficiently small absolute constant factor. Suppose the rows of $W$ are orthogonal. Then under Assumptions (normalize) through (quantization) and for any nice activation $\sigma$ (see Assumption (activation)), the following …
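The hypotheses of Theorem 1 can be instantiated numerically. The NumPy sketch below builds base weights $W$ with orthogonal rows (here orthonormal, an assumed normalization), a scale $\xi = \bar{\xi}\sqrt{k}$ with a small constant $\bar{\xi}$ (the default 0.1 is a hypothetical choice, not the paper's constant), and a unit direction $u$ whose projection onto the row span of $W$ has a prescribed norm $\|\Pi_W u\|$. The figure captions indicate that the orthogonality assumption also concerns $\|\Pi_W u\|$; setting it to zero gives the orthogonal case, while positive values correspond to the non-orthogonal runs of Figure 4. The helpers `make_instance` and `proj_norm_onto_rows` are hypothetical.

```python
import numpy as np


def make_instance(d, k, xi_bar=0.1, proj_norm=0.0, seed=0):
    """Build (W, xi, u) for a synthetic fine-tuning instance (requires d > k).

    W:  (k, d) with orthonormal rows (assumed normalization).
    xi: xi_bar * sqrt(k), matching the scaling xi ~ sqrt(k) in Theorem 1.
    u:  unit vector with ||Pi_W u|| = proj_norm, Pi_W = projection onto row span of W.
    """
    if not 0.0 <= proj_norm <= 1.0:
        raise ValueError("proj_norm must lie in [0, 1]")
    rng = np.random.default_rng(seed)
    # Reduced QR of a Gaussian (d, k) matrix gives k orthonormal columns; transpose to rows.
    W = np.linalg.qr(rng.standard_normal((d, k)))[0].T
    # Unit vector inside the row span of W.
    a = W.T @ rng.standard_normal(k)
    a /= np.linalg.norm(a)
    # Unit vector orthogonal to the row span of W (remove its projection W^T W g).
    g = rng.standard_normal(d)
    b = g - W.T @ (W @ g)
    b /= np.linalg.norm(b)
    # a and b are orthogonal unit vectors, so u has unit norm and ||Pi_W u|| = proj_norm.
    u = proj_norm * a + np.sqrt(1.0 - proj_norm ** 2) * b
    return W, xi_bar * np.sqrt(k), u


def proj_norm_onto_rows(W, u):
    """||Pi_W u||, valid when W has orthonormal rows."""
    return np.linalg.norm(W @ u)
```

Keeping $\|\Pi_W u\|$ as an explicit knob makes it easy to move from the regime covered by the theorem to the non-orthogonal regime explored only empirically.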

Figures (5)

  • Figure 1: Evolution of $\langle u_t, u\rangle$ during online SGD for 10 random instances with joint and frozen-$\hat{c}$ training. Though time scales differ between (a) and (b), trajectories exhibit similar behavior.
  • Figure 2: Linearized networks fail in low-rank fine-tuning and cannot achieve small loss. When $\hat{u}$ and $\hat{c}$ are trained jointly, we observe incremental learning: learning $c$ becomes easier once $u$ has been learned to a certain level.
  • Figure 3: Evolution of the overlap $\langle u_t, u\rangle$ during online SGD under scaling (a) $\xi = 1$ and (b) $\xi = \Theta(\sqrt{k})$, for different activations. Unlike learning single-index models, or multi-index models from scratch, the fine-tuning regime is not very sensitive to the choice of activation: the iteration complexity of fine-tuning with SGD does not depend sensitively on the information exponent.
  • Figure 4: Evolution of $\langle u_t, u\rangle$ during online SGD for (a) varying levels of $\left\|\Pi_W u\right\|$ (violating the orthogonality assumption) and (b) multiple runs with $\left\|\Pi_W u\right\|=\frac{1}{2}$. (a) For certain non-pathological initializations and scales $\xi$, the orthogonal case may capture the behavior of the non-orthogonal case. (b) Over multiple runs with $\left\|\Pi_W u\right\|=\frac{1}{2}$, we see a generic S-curve for the ReLU activation.
  • Figure 5: Evolution of $\langle u_t, u\rangle$ during online SGD for varying scales $\xi$ on (a) a short timescale and (b) a long timescale. Empirically, for "small" $\xi$ (e.g. $\xi = O(\sqrt{k})$) online SGD quickly achieves strong recovery, while as $\xi \to \infty$ the time scales for both weak and strong recovery grow. This illustrates how our low-rank fine-tuning setup allows one to interpolate between different regimes (e.g. from fine-tuning to feature learning in single-index models).
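Figure 5 contrasts the time scales for weak and strong recovery. One simple way to read such time scales off a recorded trajectory of overlaps $\langle u_t, u\rangle$ is a first-hitting time, sketched below in NumPy. The thresholds 0.5 (weak) and 0.95 (strong) are hypothetical placeholders, not the paper's definitions, and taking the absolute value assumes the perturbation is parametrized as $\hat{c}\hat{u}^\top$, which is unchanged under $(\hat{c}, \hat{u}) \mapsto (-\hat{c}, -\hat{u})$.

```python
import numpy as np


def first_hitting_time(overlaps, threshold):
    """Index of the first step with |<u_t, u>| >= threshold, or None if never reached."""
    hits = np.flatnonzero(np.abs(np.asarray(overlaps)) >= threshold)
    return int(hits[0]) if hits.size else None


def recovery_times(overlaps, weak=0.5, strong=0.95):
    """Hypothetical (weak, strong) recovery times for one overlap trajectory."""
    return first_hitting_time(overlaps, weak), first_hitting_time(overlaps, strong)
```

Computing these hitting times for trajectories at several values of $\xi$ gives a scalar summary of the trend the caption describes, without reading it off a plot.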

Theorems & Definitions (76)

  • Theorem 1: Informal, see Theorem thm:orth_frob_main
  • Theorem 2: Informal, see Theorem thm:separated_main
  • Remark 1: Other algorithms for fine-tuning
  • Theorem 3: Informal, see Theorem thm:hard-from-scratch
  • Remark 2
  • Theorem 4: Orthogonal weights, $\xi=1$
  • Theorem 5: Orthogonal weights, $\xi=\overline{\xi} \sqrt{k}$
  • Theorem 6: Angularly separated weights, $\xi=1$
  • Proposition 1
  • Remark 3: Generalizing GLM regression
  • ...and 66 more