Learning on Transformers is Provable Low-Rank and Sparse: A One-layer Analysis

Hongkang Li, Meng Wang, Shuai Zhang, Sijia Liu, Pin-Yu Chen

TL;DR

Problem: Explain why low-rank adaptation and model pruning work for Transformer learning. Approach: Theoretical analysis of a one-layer Transformer trained by SGD on a data model with $M$ orthogonal patterns, where two patterns are discriminative and labels are determined by a majority vote over label-relevant tokens. Key findings: gradient updates to $W_Q$, $W_K$, $W_V$ concentrate along the label-relevant directions, implying a low-rank structure tied to the number of label-relevant patterns; output weights $W_O$ become sparse, enabling effective pruning with bounded impact on generalization; corollaries quantify how pruning affects generalization. Numerical validation on synthetic data confirms the rank-2 effective gradient structure and predictable pruning behavior. Significance: provides a principled explanation for the practical effectiveness of LoRA and magnitude-based pruning in Transformer learning and offers guidance for efficient training.
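
For concreteness, below is a minimal sketch of the kind of synthetic data this analysis considers, assuming $M$ orthonormal patterns and a majority-vote label over the two label-relevant patterns; the sampling scheme, dimensions, and tie-breaking rule are illustrative assumptions, not the authors' exact construction.

```python
import numpy as np

rng = np.random.default_rng(0)
d, M, L = 32, 8, 10   # embedding dim, number of patterns, tokens per input

# M orthonormal patterns; patterns 0 and 1 play the role of the two
# discriminative (label-relevant) patterns.
patterns = np.linalg.qr(rng.standard_normal((d, M)))[0].T   # (M, d)

def sample_example():
    """Draw L tokens uniformly from the patterns; the label is a majority
    vote between occurrences of the two label-relevant patterns."""
    idx = rng.integers(0, M, size=L)
    X = patterns[idx]                                        # (L, d) token matrix
    y = 1 if (idx == 0).sum() > (idx == 1).sum() else -1     # ties -> -1 (assumption)
    return X, y

X, y = sample_example()
print(X.shape, y)
```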

Abstract

Efficient training and inference algorithms, such as low-rank adaptation and model pruning, have shown impressive performance in learning Transformer-based large foundation models. However, due to the technical challenges of the non-convex optimization caused by the complicated architecture of Transformers, a theoretical understanding of why these methods can be applied to learn Transformers is mostly elusive. To the best of our knowledge, this paper provides the first theoretical analysis of the low-rank and sparsity properties of one-layer Transformers by characterizing the trained model after convergence with stochastic gradient descent. By focusing on a data model based on label-relevant and label-irrelevant patterns, we show that the gradient updates of the trainable parameters are low-rank, with rank determined by the number of label-relevant patterns. We also analyze how model pruning affects generalization while improving computational efficiency and conclude that proper magnitude-based pruning has only a slight effect on testing performance. We conduct numerical experiments to support our findings.

Paper Structure

This paper contains 9 sections, 2 theorems, 8 equations, 3 figures, 1 table.

Key Result

Theorem 1

(low rank and sparsity) Suppose all conditions for zero generalization error, $f(\Psi)=0$, in Theorem 1 of LWLC23 hold. Let $\Delta{\boldsymbol W}={\boldsymbol W}^{(T)}-{\boldsymbol W}^{(0)}$ denote the accumulated gradient update of ${\boldsymbol W}$ after convergence. Let $\alpha_*$ be the average fraction of label
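
A minimal sketch of how the low-rank claim can be checked empirically: train a toy one-layer single-head attention model with SGD on data from the model above, then inspect the singular values of the accumulated update $\Delta{\boldsymbol W}_K$. The architecture, hinge loss, readout vector, and hyperparameters here are illustrative assumptions, not the paper's exact setup.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
torch.manual_seed(0)
d, M, L, n, steps, lr = 32, 8, 10, 512, 300, 0.5

# Synthetic data: tokens drawn from M orthonormal patterns, majority-vote labels.
patterns = np.linalg.qr(rng.standard_normal((d, M)))[0].T
P = torch.tensor(patterns, dtype=torch.float32)              # (M, d)
idx = rng.integers(0, M, size=(n, L))
X = P[torch.tensor(idx)]                                     # (n, L, d)
y = torch.tensor((idx == 0).sum(1) > (idx == 1).sum(1), dtype=torch.float32) * 2 - 1

# One-layer single-head attention with a fixed linear readout.
WQ, WK, WV = ((0.1 * torch.randn(d, d)).requires_grad_() for _ in range(3))
a = torch.randn(d) / d**0.5
W0 = [W.detach().clone() for W in (WQ, WK, WV)]

for _ in range(steps):
    scores = torch.einsum('nld,ntd->nlt', X @ WQ, X @ WK) / d**0.5
    out = (torch.softmax(scores, dim=-1) @ (X @ WV)).mean(dim=1) @ a
    loss = torch.clamp(1 - y * out, min=0).mean()            # hinge loss
    loss.backward()
    with torch.no_grad():
        for W in (WQ, WK, WV):
            W -= lr * W.grad
            W.grad = None

# If the theory carries over, most singular-value mass of Delta W_K should
# concentrate in the top ~2 directions spanned by the label-relevant patterns.
print(torch.linalg.svdvals(WK.detach() - W0[1])[:5])
```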

Figures (3)

  • Figure 1: When all trainable parameters are approximated with rank equal to $1$, $2$, $5$, $10$, and $20$ (full rank): (A) the testing hinge loss; (B) the sum of attention weights on label-relevant tokens.
  • Figure 2: The singular values of $\Delta{\boldsymbol W}_K^{(t)}$ at (A) $t=1$, (B) $t=10$, (C) $t=20$, (D) $t=30$.
  • Figure 3: (A) The magnitudes of the trained ${\boldsymbol W}_O$ neurons; (B) the testing performance of magnitude-based pruning with different pruning rates (see the pruning sketch after this list).
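
Alongside Figure 3, here is a minimal sketch of magnitude-based pruning on an output matrix, assuming a stand-in ${\boldsymbol W}_O$ whose neuron magnitudes are concentrated on a few rows; the paper's trained weights are not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(0)
m, d = 64, 32
# Stand-in trained W_O: a few high-magnitude neurons, the rest near zero,
# mimicking the sparsity pattern the paper predicts.
scale = np.concatenate([np.ones(8), 0.05 * np.ones(m - 8)])
WO = rng.standard_normal((m, d)) * scale[:, None]

def magnitude_prune(W, rate):
    """Zero out the fraction `rate` of neurons (rows) with the smallest l2 norm."""
    k = int(rate * W.shape[0])
    keep = np.argsort(np.linalg.norm(W, axis=1))[k:]
    Wp = np.zeros_like(W)
    Wp[keep] = W[keep]
    return Wp

for rate in (0.0, 0.25, 0.5, 0.75, 0.9):
    Wp = magnitude_prune(WO, rate)
    rel = np.linalg.norm(Wp - WO) / np.linalg.norm(WO)
    print(f"pruning rate {rate:.2f}: relative weight change {rel:.3f}")
```

Because the magnitudes concentrate on a few neurons, moderate pruning rates barely perturb the weights, consistent with the bounded effect on generalization stated in the corollaries.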

Theorems & Definitions (3)

  • Theorem 1
  • Remark 1
  • Corollary 1