Weight decay induces low-rank attention layers

Seijin Kobayashi, Yassir Akram, Johannes Von Oswald

TL;DR

We verify empirically that weight decay, as commonly used in vision tasks and language modelling, significantly reduces the rank of the key-query and value-projection matrix products $W_K^\top W_Q$ and $PW_V$ within attention layers, even in fully online training.

Abstract

The effect of regularizers such as weight decay when training deep neural networks is not well understood. We study the influence of weight decay as well as $L2$-regularization when training neural network models in which parameter matrices interact multiplicatively. This combination is of particular interest as this parametrization is common in attention layers, the workhorse of transformers. Here, key-query, as well as value-projection parameter matrices, are multiplied directly with each other: $W_K^\top W_Q$ and $PW_V$. We extend previous results and show on one hand that any local minimum of an $L2$-regularized loss of the form $L(AB^\top) + \lambda(\|A\|^2 + \|B\|^2)$ coincides with a minimum of the nuclear norm-regularized loss $L(AB^\top) + \lambda\|AB^\top\|_*$, and on the other hand that the two losses become identical exponentially quickly during training. We thus complement existing works linking $L2$-regularization with low-rank regularization, and in particular, explain why such regularization on the matrix product affects early stages of training. Based on these theoretical insights, we verify empirically that the key-query and value-projection matrix products $W_K^\top W_Q$ and $PW_V$ within attention layers, when optimized with weight decay, as is usual in vision tasks and language modelling, indeed undergo a significant rank reduction, even in fully online training. We find that, in accordance with existing work, inducing low rank in attention matrix products can damage language model performance, and observe advantages when decoupling weight decay in attention layers from the rest of the parameters.
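
The claim that the two losses become identical exponentially quickly can be illustrated with a short derivation. This is a sketch under assumptions not stated in the abstract: training is idealized as gradient flow (continuous time, no momentum or adaptive scaling), $\|\cdot\|$ is the Frobenius norm, and $G$ is shorthand introduced here for the gradient of $L$ at $AB^\top$. It is not necessarily the exact statement proved in the paper.

```latex
\begin{align*}
\dot A &= -G B - 2\lambda A, \qquad
\dot B = -G^\top A - 2\lambda B, \qquad
G := \nabla L(AB^\top), \\
\frac{d}{dt}\left(A^\top A\right) &= -B^\top G^\top A - A^\top G B - 4\lambda\, A^\top A, \\
\frac{d}{dt}\left(B^\top B\right) &= -A^\top G B - B^\top G^\top A - 4\lambda\, B^\top B, \\
\frac{d}{dt}\left(A^\top A - B^\top B\right) &= -4\lambda \left(A^\top A - B^\top B\right)
\;\Longrightarrow\;
A^\top A - B^\top B = e^{-4\lambda t} \left.\left(A^\top A - B^\top B\right)\right|_{t=0}.
\end{align*}
```

The loss-dependent terms cancel, so under these assumptions the imbalance $A^\top A - B^\top B$ decays at a rate set only by $\lambda$, whatever $L$ is. Once the factors are balanced, Proposition 3.1 ties the $L2$ penalty to the nuclear norm of the product.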

Paper Structure

This paper contains 37 sections, 9 theorems, 62 equations, 7 figures, 2 tables.

Key Result

Proposition 3.1

Let $A,B$ be matrices such that $A^\top A=B^\top B$. Then, denoting by $AB^\top = U S V^\top$ the SVD of $AB^\top$, there exists an orthogonal matrix $O$ such that $A = U \sqrt{S}\, O^\top$ and $B = V \sqrt{S}\, O^{\top}$. In particular, $\|AB^\top\|_*=\frac{1}{2}(\|A\|^2 + \|B\|^2)$.
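
The norm identity in Proposition 3.1 can be checked numerically. The following is a minimal sketch, not code from the paper: the matrix sizes, the random seed, and the helper names `nuclear_norm` and `half_frobenius` are all assumptions made for illustration. A balanced pair is built as $B = QA$ with $Q$ orthogonal, which guarantees $B^\top B = A^\top A$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 6, 4  # hypothetical sizes, chosen only for illustration

# Build a balanced pair: B = Q A with Q orthogonal gives B^T B = A^T A.
A = rng.standard_normal((n, r))
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
B = Q @ A

def nuclear_norm(M):
    """Sum of the singular values of M."""
    return np.linalg.svd(M, compute_uv=False).sum()

def half_frobenius(A, B):
    """(||A||_F^2 + ||B||_F^2) / 2."""
    return 0.5 * (np.sum(A**2) + np.sum(B**2))

balance_gap = np.linalg.norm(A.T @ A - B.T @ B)
nuc = nuclear_norm(A @ B.T)
frob = half_frobenius(A, B)

# Rescaling (A, B) -> (c A, B / c) keeps the product A B^T but breaks the
# balance condition, so the Frobenius expression exceeds the nuclear norm.
c = 3.0
A_u, B_u = c * A, B / c
nuc_u = nuclear_norm(A_u @ B_u.T)
frob_u = half_frobenius(A_u, B_u)

print(f"balanced:   imbalance = {balance_gap:.2e}, nuclear = {nuc:.6f}, half-Frobenius = {frob:.6f}")
print(f"unbalanced: nuclear = {nuc_u:.6f}, half-Frobenius = {frob_u:.6f}")
```

The unbalanced pair shows why the hypothesis $A^\top A = B^\top B$ matters: the product, and hence its nuclear norm, is unchanged by the rescaling, while the sum of squared Frobenius norms grows.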

Figures (7)

  • Figure 1: Optimization by gradient descent of two $5$-by-$5$ matrices $A,B$ on the $L2$-regularized loss $\|AB^\top - D\|^2 + \frac{\lambda}{2} (\|A\|^2 +\|B\|^2)$ where $D=\text{diag}(0.2,0.4,0.6,0.8,1)$, for various regularization strengths $\lambda$. $t$ denotes the number of optimization steps. Left: difference between the nuclear norm $\|AB^\top\|_*$ and the Frobenius-norm bound $\frac{1}{2}\|A\|^2+\frac{1}{2}\|B\|^2$ throughout optimization. In all cases other than $\lambda=0$, the difference converges exponentially quickly to $0$, as predicted by our theory. Center left: Norm of the discrepancy between $A^\top A$ and $B^\top B$ over training steps. As predicted, the discrepancy vanishes exponentially, at a rate proportional to $\lambda$. Center right: Singular values of the matrix $AB^\top$ at $t=1000$, for various regularization strengths $\lambda$. As predicted, $s_i$ decays linearly with $\lambda$, until $\lambda \geq s_i$, at which point the singular value vanishes. Right: Singular values of the matrix $AB^\top$ during optimization, for $\lambda = 0.4$.
  • Figure 2: Left: The rank of the weight matrix product $PW_V$ of the first layer of a 2-layer Transformer trained on the associative recall task, during training with AdamW, for various decay strengths. To isolate the effect of weight decay on the attention layers, only the decay strength applied to attention layers is varied, while the strength for all other layers is fixed at $0.1$. We observe that rank reduction correlates strongly with weight decay strength. Center: Norm of the discrepancy between $P^\top P$ and $W_VW_V^\top$ during training. As predicted, the difference appears to converge to $0$ towards the end of training when $\lambda >0$. While for AdamW we no longer have the guarantee of an exponential decay, the discrepancy nonetheless vanishes quickly, with a time constant that correlates perfectly with the decay strength. Right: The difference between the nuclear norm of $PW_V$ and the Frobenius norm upper bounding it. As the discrepancy between $P^\top P$ and $W_VW_V^\top$ decreases, the difference approaches $0$, and thus the bound becomes tight. The optimization of $\mathcal{L}_{L2}$ thus gradually switches to that of $\mathcal{L}_*$, explaining the rank regularization. Qualitative findings are identical when studying $W_K^\top W_Q$.
  • Figure 3: Left, center left: The rank of weight matrix products $W_K^\top W_Q$ and $PW_V$ averaged across heads of layer 5 of an autoregressive transformer trained on the Pile [gao_pile_2020]. Center right, right: The rank of weight matrix products $W_K^\top W_Q$ and $PW_V$ averaged over all heads and all layers of a Vision Transformer trained following [irandoust2022training] on the ImageNet dataset [deng2009imagenet]. In both settings, the decay strength applied to attention layers is varied, while the strength for all other layers is kept fixed. In all cases, we again observe that rank reduction correlates strongly with weight decay strength when optimizing with AdamW. The weight decay strength of $0.1$ commonly used to pretrain some known large foundation models in fact noticeably reduces the rank of the resulting matrices compared to training with weight decay turned off.
  • Figure 4: Analyses of attention layers in the pretrained LLAMA 2 model with 7 billion parameters [touvron2023llama]. The leftmost (resp. center left) plot shows the squared norm of every row of $W_Q$ (resp. $W_V$), for the first head of each layer, against the norm of the corresponding row of $W_K$ (resp. column of $P$). The condition $W_KW_K^\top = W_QW_Q^\top$ would require these norms to be equal, which in fact is mostly true. While the model has not reached a stationary point, this indicates that the optimization has advanced far enough for this sufficient condition for $\mathcal{L}_*$ to be identical to $\mathcal{L}_{L2}$ to emerge. In fact, the center right (resp. rightmost) plot shows a scatter plot of the Frobenius norm against the nuclear norm for all heads across all layers. The two norms almost perfectly coincide.
  • Figure 5: Trajectory of $w_1, w_2$ in the 2D plane when optimizing the underlying parameters for various hyperparameters. At every coordinate in the plane, the loss is defined as the squared distance to the surface $\mathcal{S}$ in orange. The red (resp. blue) cross represents the point on $\mathcal{S}$ minimizing the $L2$-norm (resp. $L1$-norm). Left: $w_1, w_2$ are directly parametrized and optimized by AdamW with decoupled weight decay (solid line) or Adam with $L2$-regularization (dotted line). As conjectured, the convergence point of AdamW given the hyperparameter $\epsilon$ and decay strength $\lambda_{wd}$ coincides with the equilibrium point of the $L2$-regularized loss with regularization strength $\lambda_{L2} = \lambda_{wd} \epsilon$. Right: $w_1, w_2$ are parameterized as a product of two scalars, i.e. $w_1=a_1 b_1, w_2=a_2 b_2$, where $a_1, b_1,a_2, b_2$ are now optimized by AdamW or Adam with $L2$-regularization. Again, the two optimizers find the same convergence point for equivalent hyperparameters. However, the solution found now corresponds to that of the loss regularized by the $L1$-norm of $w_1,w_2$ (which is the nuclear norm for scalars), as predicted.
  • ...and 2 more figures
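
The toy experiment described in the Figure 1 caption can be sketched in a few lines of NumPy. The loss and the target $D$ are taken from the caption. Everything else is an assumption made here for illustration and may differ from the authors' setup: plain full-batch gradient descent, the learning rate, the number of steps, the Gaussian initialization scale, the seed, and the function name `run`.

```python
import numpy as np

def nuclear_norm(M):
    """Sum of the singular values of M."""
    return np.linalg.svd(M, compute_uv=False).sum()

def run(lam, steps=5000, lr=0.05, init_scale=0.3, seed=0):
    """Gradient descent on ||A B^T - D||^2 + lam/2 (||A||^2 + ||B||^2).

    steps, lr, init_scale and seed are hypothetical choices, not values
    taken from the paper.
    """
    rng = np.random.default_rng(seed)
    D = np.diag([0.2, 0.4, 0.6, 0.8, 1.0])
    A = init_scale * rng.standard_normal((5, 5))
    B = init_scale * rng.standard_normal((5, 5))

    def stats(A, B):
        # Discrepancy between A^T A and B^T B, and the gap between the
        # Frobenius bound and the nuclear norm of the product.
        imbalance = np.linalg.norm(A.T @ A - B.T @ B)
        gap = 0.5 * (np.sum(A**2) + np.sum(B**2)) - nuclear_norm(A @ B.T)
        return imbalance, gap

    start = stats(A, B)
    for _ in range(steps):
        R = A @ B.T - D                   # residual of the data term
        grad_A = 2.0 * R @ B + lam * A    # gradient with respect to A
        grad_B = 2.0 * R.T @ A + lam * B  # gradient with respect to B
        A, B = A - lr * grad_A, B - lr * grad_B
    end = stats(A, B)
    singular_values = np.linalg.svd(A @ B.T, compute_uv=False)
    return start, end, singular_values

start, end, sv = run(lam=0.2)
print("imbalance  start/end:", start[0], end[0])
print("norm gap   start/end:", start[1], end[1])
print("singular values of A B^T:", np.round(sv, 3))
```

Calling `run` with other values of `lam`, including `0.0`, gives a rough way to compare how the imbalance and the norm gap evolve with the regularization strength, in the spirit of the left and center-left panels.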

Theorems & Definitions (16)

  • Proposition 3.1
  • Lemma 3.2
  • Theorem 3.3
  • Theorem 3.4
  • proof
  • proof
  • proof
  • Lemma B.1
  • proof
  • Proposition B.2
  • ...and 6 more