
Abrupt Learning in Transformers: A Case Study on Matrix Completion

Pulkit Gopalani, Ekdeep Singh Lubana, Wei Hu

TL;DR

The low-rank matrix completion problem is formulated as a masked language modeling (MLM) task, and a BERT model is trained to solve this task to low error.

Abstract

Recent analysis of the training dynamics of Transformers has revealed an interesting characteristic: the training loss plateaus for a significant number of training steps and then suddenly (and sharply) drops to near-optimal values. To understand this phenomenon in depth, we formulate the low-rank matrix completion problem as a masked language modeling (MLM) task, and show that it is possible to train a BERT model to solve this task to low error. Furthermore, the loss curve shows a plateau early in training followed by a sudden drop to near-optimal values, despite no changes in the training procedure or hyper-parameters. To gain interpretability insights into this sudden drop, we examine the model's predictions, attention heads, and hidden states before and after this transition. Concretely, we observe that (a) the model transitions from simply copying the masked input to accurately predicting the masked entries; (b) the attention heads transition to interpretable patterns relevant to the task; and (c) the embeddings and hidden states encode information relevant to the problem. We also analyze the training dynamics of individual model components to understand the sudden drop in loss.
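The MLM formulation can be made concrete with a small sketch. It assumes a rank-r ground-truth matrix built from Gaussian factors, entries masked independently with probability p_mask, masked entries replaced by 0 as a stand-in for a mask token, and row-major flattening into a token sequence. These choices, and the helper names `make_mlm_example` and `masked_mse`, are hypothetical illustrations rather than the paper's exact setup; the loss here is taken over masked positions only, which is likewise an assumption. The usage lines also score the "copying" baseline, which returns the input unchanged.

```python
import numpy as np

def make_mlm_example(n=7, m=7, rank=2, p_mask=0.3, seed=0):
    """Hypothetical sketch: one masked low-rank matrix as a flat token sequence."""
    rng = np.random.default_rng(seed)
    # Rank-`rank` ground truth M = U V^T with Gaussian factors (assumed distribution).
    U = rng.standard_normal((n, rank))
    V = rng.standard_normal((m, rank))
    M = U @ V.T
    # Each entry is masked independently with probability p_mask.
    mask = rng.random((n, m)) < p_mask
    # Masked entries are set to 0, standing in for a [MASK] token (assumption).
    X = np.where(mask, 0.0, M)
    # Row-major flattening: entry (i, j) becomes sequence position i * m + j.
    return X.reshape(-1), M.reshape(-1), mask.reshape(-1)

def masked_mse(pred, target, mask):
    """Mean squared error restricted to the masked positions (assumed loss)."""
    if not mask.any():
        return 0.0
    return float(np.mean((pred[mask] - target[mask]) ** 2))

# Usage: compare the copying baseline with a perfect completion.
seq, target, mask = make_mlm_example()
copy_loss = masked_mse(seq, target, mask)       # output = input, masked slots stay 0
oracle_loss = masked_mse(target, target, mask)  # exact completion
```

Under these assumptions the copying baseline's loss equals the mean squared value of the hidden entries, which gives a sense of the plateau level a model stuck in the copying phase would sit at.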

Paper Structure

This paper contains 49 sections, 7 equations, 27 figures, 1 table.

Figures (27)

  • Figure 1: (A) Matrix completion using BERT. Similar to completing missing words in an English sentence in MLM, we complete missing entries in a masked low-rank matrix. (B) Sudden drop in loss. During training, the model undergoes an algorithmic shift marked by a sharp decrease in mean squared error (MSE) loss. Here, the model shifts from simply copying the input (copying phase) to computing missing entries accurately (completion phase).
  • Figure 2: Sharp reduction in training loss.
  • Figure 3: BERT v. Nuclear Norm Minimization. Comparing our model (trained with $p_{\rm mask}=0.3$) and nuclear norm minimization on the matrix completion task at various levels of $p_{\rm mask}$. The difference in MSE and nuclear norm of solutions obtained using these two approaches indicates that BERT is not implicitly doing nuclear norm minimization to complete missing entries.
  • Figure 4: Attention heads in the post-shift model attend to specific positions. For example, (Layer 2, Head 1) attends to elements in the same row as the query element, and (Layer 2, Head 2) attends to elements in the same column as the query element. (These attention matrices are an average over multiple independent matrix and mask samples.)
  • Figure 5: Attention heads with a specific mask structure in the inputs. We can derive fine-grained insights about the functions of individual heads in this setup by using a specific mask structure for all input matrices. (The mask is shown below each plot; blue denotes missing entries.) For example, multiple attention heads like (Layer 2, Head 2) have negligible attention weight at missing positions in the input matrix, implying that these heads attend only to observed entries in the column of the query element. Further, (Layer 2, Head 1) and similar heads have larger attention weights for the rows with missing entries, and in those rows they attend to the sole observed element.
  • ...and 22 more figures
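Figure 3 compares BERT against nuclear norm minimization. The solver behind that baseline is not specified here, so the sketch below swaps in Soft-Impute (Mazumder, Hastie and Tibshirani, 2010), which solves the penalized problem min_Z ½‖P_Ω(Z − M)‖²_F + λ‖Z‖_* by repeated singular-value soft-thresholding; as λ shrinks toward 0 its solution roughly approaches the equality-constrained nuclear norm minimizer. The function names `soft_impute` and `nuclear_norm`, the value of λ, and the iteration count are illustrative assumptions, not the paper's configuration.

```python
import numpy as np

def soft_impute(X_obs, observed, lam=0.1, n_iters=500):
    """Soft-Impute: iterative singular-value soft-thresholding for
    min_Z 0.5 * ||P_obs(Z - X_obs)||_F^2 + lam * ||Z||_* (a swapped-in solver)."""
    Z = np.zeros(X_obs.shape)
    for _ in range(n_iters):
        # Keep observed entries from the data; fill missing ones from the current estimate.
        filled = np.where(observed, X_obs, Z)
        U, s, Vt = np.linalg.svd(filled, full_matrices=False)
        # Proximal step for the nuclear norm: shrink every singular value by lam.
        s_shrunk = np.maximum(s - lam, 0.0)
        Z = (U * s_shrunk) @ Vt
    return Z

def nuclear_norm(Z):
    """Sum of singular values."""
    return float(np.linalg.svd(Z, compute_uv=False).sum())

# Usage: complete a 7x7 rank-2 matrix with p_mask = 0.3 (hypothetical sizes).
rng = np.random.default_rng(0)
M = rng.standard_normal((7, 2)) @ rng.standard_normal((2, 7))
observed = rng.random(M.shape) >= 0.3  # each entry missing with probability 0.3
Z = soft_impute(np.where(observed, M, 0.0), observed, lam=0.05)
mse_missing = float(np.mean((Z[~observed] - M[~observed]) ** 2))
```

Comparing `mse_missing` and `nuclear_norm(Z)` with the corresponding quantities for a trained model's completions is the kind of check Figure 3 describes; a large gap in nuclear norm suggests the model is not implicitly minimizing it.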