Linguistic Collapse: Neural Collapse in (Large) Language Models

Robert Wu, Vardan Papyan

TL;DR

This paper empirically investigates the impact of scaling the architectures and training of causal language models (CLMs) on their progression towards $\mathcal{NC}$ and finds that properties that develop with scale (and regularization) are linked to generalization.

Abstract

Neural collapse ($\mathcal{NC}$) is a phenomenon observed in classification tasks where top-layer representations collapse into their class means, which become equinorm, equiangular and aligned with the classifiers. These behaviours -- associated with generalization and robustness -- would manifest under specific conditions: models are trained towards zero loss, with noise-free labels belonging to balanced classes, which do not outnumber the model's hidden dimension. Recent studies have explored $\mathcal{NC}$ in the absence of one or more of these conditions to extend and capitalize on the associated benefits of ideal geometries. Language modelling presents a curious frontier, as training by token prediction constitutes a classification task where none of the conditions exist: the vocabulary is imbalanced and exceeds the embedding dimension; different tokens might correspond to similar contextual embeddings; and large language models (LLMs) in particular are typically only trained for a few epochs. This paper empirically investigates the impact of scaling the architectures and training of causal language models (CLMs) on their progression towards $\mathcal{NC}$. We find that $\mathcal{NC}$ properties that develop with scale (and regularization) are linked to generalization. Moreover, there is evidence of some relationship between $\mathcal{NC}$ and generalization independent of scale. Our work thereby underscores the generality of $\mathcal{NC}$ as it extends to the novel and more challenging setting of language modelling. Downstream, we seek to inspire further research on the phenomenon to deepen our understanding of LLMs -- and neural networks at large -- and improve existing architectures based on $\mathcal{NC}$-related properties. Our code is hosted on GitHub at https://github.com/rhubarbwu/linguistic-collapse.

Paper Structure (67 sections, 12 equations, 30 figures, 5 tables)

Figures (30)

  • Figure 1: Simultaneous development of the four neural collapse ($\mathcal{NC}$) properties [papyan2020prevalence] in 230 causal language models trained on TinyStories [eldan2023tinystories], alongside improvement in generalization (i.e. validation performance). Left to right: ($\mathcal{NC}_1$) within-class (representation) variability collapse; ($\mathcal{GNC}_2$) hyperspherical uniformity of class means; ($\mathcal{UNC}_3$) uniform duality between class means and corresponding classifiers; and ($\mathcal{NC}_4$) agreement between token (maximum a posteriori) classifiers and implicit nearest-class-centre classifiers. Points are coloured by model size and annotated with the coefficient of determination ($R^2$).
  • Figure 2: Validation loss is correlated with all three measurements: (left) equinormness ($\mathcal{NC}_2$) expressed as variation in logarithmic norms; (centre) equiangularity ($\mathcal{NC}_2$) as variation in interference; (right) hyperspherical uniformity ($\mathcal{GNC}_2$) as variation in logarithmic pairwise distances.
  • Figure 3: Validation loss shows a negligible relationship with self-duality ($\mathcal{NC}_3$, left) and some correlation with uniform duality ($\mathcal{UNC}_3$, right). In other words, $\mathcal{UNC}_3$ develops with scale and correlates with generalization much better than $\mathcal{NC}_3$.
  • Figure 4: The 500 most frequent classes from TinyStories [eldan2023tinystories] exhibit significant sample imbalance. Despite the synthetic nature of TinyStories, such a distribution is typical of natural language [shannon1948mathematical, florence1950human].
  • Figure 5: Average (logarithmic) class-distance normalized variance (CDNV, $\mathcal{NC}_1$) (left) and validation (cross-entropy) loss (right) with respect to training epochs.
  • ...and 25 more figures
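
To make two of the measurements named in the captions above concrete, here is a minimal pure-Python sketch. It is not taken from the authors' repository. It assumes the common definition of class-distance normalized variance for a pair of classes, $(\sigma_a^2 + \sigma_b^2) / (2 \lVert \mu_a - \mu_b \rVert^2)$, and it assumes that "variation in logarithmic norms" can be illustrated by the standard deviation of the log-norms of class means that have already been centred. The function names `class_stats`, `cdnv` and `log_norm_spread` are hypothetical, and the paper's exact statistics (for example how pairs are averaged or how variation is normalized) may differ.

```python
import math
from collections import defaultdict
from statistics import pstdev


def class_stats(features, labels):
    """Return {label: (mean vector, mean squared distance to that mean)}."""
    groups = defaultdict(list)
    for x, y in zip(features, labels):
        groups[y].append(x)
    stats = {}
    for y, xs in groups.items():
        mu = [sum(col) / len(xs) for col in zip(*xs)]
        var = sum(sum((a - m) ** 2 for a, m in zip(x, mu)) for x in xs) / len(xs)
        stats[y] = (mu, var)
    return stats


def cdnv(stats, a, b):
    """Class-distance normalized variance for one pair of classes (assumed form)."""
    (mu_a, var_a), (mu_b, var_b) = stats[a], stats[b]
    dist2 = sum((p - q) ** 2 for p, q in zip(mu_a, mu_b))
    return (var_a + var_b) / (2 * dist2)


def log_norm_spread(means):
    """Standard deviation of log-norms of (already centred) class means.

    Illustrative stand-in for an equinormness measure: 0 means all norms are equal.
    """
    logs = [math.log(math.sqrt(sum(v * v for v in mu))) for mu in means]
    return pstdev(logs)


# Toy data: two classes of two 2-D feature vectors each.
stats = class_stats([[1, 0], [3, 0], [0, 2], [0, 4]], [0, 0, 1, 1])
pair_cdnv = cdnv(stats, 0, 1)
equal_norm_spread = log_norm_spread([[3, 4], [5, 0], [0, -5]])
```

Lower pairwise values from `cdnv` indicate tighter clusters relative to the separation of their means, which is the direction associated with $\mathcal{NC}_1$. A spread of zero from `log_norm_spread` corresponds to perfectly equinorm class means.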