
On the Dynamics & Transferability of Latent Generalization during Memorization

Simran Ketha, Venkatakrishnan Ramaswamy

Abstract

Deep networks are known to have extraordinary generalization abilities, via mechanisms that are not yet well understood. It is also known that when labels in the training data are shuffled to varying degrees, deep networks trained with standard methods can still achieve perfect or high accuracy on this corrupted training data. This phenomenon is called memorization, and it typically comes at the cost of poorer generalization to true labels. Our recent work has demonstrated that the internal representations of such models retain significantly better latent generalization abilities than is directly apparent from the model. In particular, it has been shown that such latent generalization can be recovered via simple probes (called MASC probes) on the layer-wise representations of the model. However, the origin of this latent generalization, and its dynamics over training during memorization, are not well understood. Here, we track the training dynamics empirically and find that latent generalization abilities largely peak early in training, along with model generalization. Next, we investigate to what extent the specific nature of the MASC probe is critical for our ability to extract latent generalization from the model's layer-wise outputs. To this end, we first examine the mathematical structure of the MASC probe and show that it is a quadratic classifier, i.e., it is non-linear. This raises the question of the extent to which this latent generalization is linearly decodable from layer-wise outputs. To investigate this, we design a new linear probe for this setting. Finally, we consider whether it is possible to transfer latent generalization to model generalization by directly editing model weights. To this end, we devise a way to transfer the latent generalization present in last-layer representations to the model using the new linear probe.

Paper Structure (24 sections, 2 theorems, 23 equations, 24 figures, 4 algorithms)

This paper contains 24 sections, 2 theorems, 23 equations, 24 figures, 4 algorithms.

Key Result

Proposition 1

MASC is a quadratic classifier.
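
One way to see why a minimum-angle subspace rule is quadratic: if $U_c$ is an orthonormal basis for the subspace of class $c$, the angle $\theta_c$ between an input $x$ and that subspace satisfies $\cos\theta_c = \lVert U_c^\top x \rVert / \lVert x \rVert$, so choosing the smallest angle is the same as choosing the largest quadratic form $x^\top U_c U_c^\top x$. The sketch below checks this equivalence numerically. It is an illustration under stated assumptions, not the paper's proof or implementation: the helper names, the SVD-based subspace construction without centering, and the subspace dimension `k` are all hypothetical choices made here.

```python
import numpy as np

def class_subspace_basis(X, k):
    """Orthonormal basis (d x k) spanning the top-k right singular
    directions of X (n x d). Assumed construction, for illustration only."""
    _, _, vt = np.linalg.svd(X, full_matrices=False)
    return vt[:k].T

def masc_predict_angle(x, bases):
    """Pick the class whose subspace makes the smallest angle with x."""
    angles = []
    for U in bases:
        proj = U @ (U.T @ x)  # orthogonal projection of x onto span(U)
        cos = np.linalg.norm(proj) / np.linalg.norm(x)
        angles.append(np.arccos(np.clip(cos, 0.0, 1.0)))
    return int(np.argmin(angles))

def masc_predict_quadratic(x, bases):
    """Same decision, written as an argmax over quadratic forms
    x^T P_c x with P_c = U_c U_c^T."""
    scores = [x @ (U @ U.T) @ x for U in bases]
    return int(np.argmax(scores))

# Hypothetical usage on synthetic data: 3 classes, 20-dim features, 3-dim subspaces.
rng = np.random.default_rng(0)
bases = [class_subspace_basis(rng.normal(size=(50, 20)), k=3) for _ in range(3)]
x = rng.normal(size=20)
print(masc_predict_angle(x, bases), masc_predict_quadratic(x, bases))
```

Because each score is a quadratic form in $x$, the boundary between any two classes is the zero set of $x^\top (P_a - P_b)\, x$, which is in general not a hyperplane.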

Figures (24)

  • Figure 1: Minimum Angle Subspace Classifier (MASC) test accuracy over epochs of training for multiple models/datasets, where test data is projected onto class-specific subspaces constructed at each epoch from corrupted training data with the indicated label corruption degree. The plots display MASC accuracy across different layers of the network. For reference, the evolution of test accuracy of the corresponding model (blue dotted line) over epochs of training is also shown. FC denotes fully connected layers with ReLU activation, and Flat refers to the flatten layer without ReLU.
  • Figure 2: Vector Linear Probe Intermediate-layer Classifier (VeLPIC) test accuracy during training of the network, where test data is projected onto class vectors constructed at each epoch from training data with the indicated label corruption degrees. The plots display VeLPIC accuracy across different layers of the network for various model–dataset combinations. For reference, the test accuracy of the models (blue dotted line) over epochs of training is also shown. FC denotes fully connected layers with ReLU activation, and Flat refers to the flatten layer without ReLU.
  • Figure 3: Difference in test accuracy (VeLPIC accuracy minus MASC accuracy) during training of the network, where test data is projected onto class vectors constructed at each epoch from training data with the indicated label corruption degrees. The plots display the difference in accuracy across different layers of the network for various model–dataset combinations. For reference, the test accuracy of the models (blue dotted line) over epochs of training is also shown, which would be 0.
  • Figure 4: Logistic regression probe's test accuracy over epochs of training for multiple models/datasets. The plots display logistic regression probe's accuracy across different layers of the network. For reference, the evolution of test accuracy of the corresponding model (blue dotted line) over epochs of training is also shown.
  • Figure 5: Difference in test accuracy (MASC accuracy minus logistic regression probe accuracy) during training of the network, where, for MASC, test data is projected onto class-specific subspaces constructed at each epoch from training data with the indicated label corruption degrees. The plots display the difference in accuracy across different layers of the network for various model–dataset combinations. For reference, the test accuracy of the models (blue dotted line) over epochs of training is also shown, which would be 0.
  • ...and 19 more figures

Theorems & Definitions (4)

  • Proposition 1
  • proof
  • Proposition 2
  • proof