Transformer-based Causal Language Models Perform Clustering

Xinbo Wu, Lav R. Varshney

TL;DR

The authors introduce a simplified instruction-following task and use synthetic datasets to analyze a Transformer-based causal language model. Their findings suggest that the model learns task-specific information by clustering data within its hidden space, and that this clustering evolves dynamically during learning.

Abstract

Even though large language models (LLMs) have demonstrated remarkable capability in solving various natural language tasks, the capability of an LLM to follow human instructions is still a concern. Recent work has shown large improvements in instruction-following capability through additional training on instruction-following tasks. However, the mechanisms responsible for effective instruction following remain inadequately understood. Here, we introduce a simplified instruction-following task and use synthetic datasets to analyze a Transformer-based causal language model. Our findings suggest that the model learns task-specific information by clustering data within its hidden space, with this clustering process evolving dynamically during learning. We also demonstrate how this phenomenon helps the model handle unseen instances, and we validate our results in a more realistic setting. Furthermore, we present applications to pre-training and alignment inspired by these findings.

Paper Structure

This paper contains 26 sections, 10 figures, and 9 tables.

Figures (10)

  • Figure 1: A synthetic dataset for our simplified instruction-following setting. The first task colored green is shown as an example with instructions sampled from $M$ different distributions colored pink. The task function consists of five mappings, in which "->" means from an input to an output. There are several instructions sampled via a regular expression under each distribution.
  • Figure 2: Clustering analysis on both the training subset (a) and the validation set (b) across different layers throughout the training process. Different columns correspond to using different identities as labels. Only F1 scores are shown here; see Figure \ref{fig:learning_dynamics_all} for results on the other evaluation metrics. Each dot represents a data point.
  • Figure 3: (a) Training loss, (b) training-subset task accuracy, and (c) validation task accuracy throughout the training process. Each dot represents a data point. Both (b) and (c) show dense dots with near-zero accuracy for the first few epochs.
  • Figure 4: (a) Percentage of an unseen instance's K nearest neighbors in the training set that belong to the same task identity. (b) K-nearest-neighbor accuracy. Measurements are performed across all layers and throughout the training process.
  • Figure 5: (a) Clustering analysis of the model trained for task ID prediction, following the setting in Section \ref{sec:clustering_analysis}. (b) Comparison of different pre-training strategies by their performance during fine-tuning. Task accuracy is measured on the validation set.
  • ...and 5 more figures
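
The nearest-neighbor measurement described in the Figure 4 caption can be illustrated with a minimal sketch. This is not the authors' code: the function names, the Euclidean distance, the majority-vote reading of "K nearest neighbors accuracy", and the toy two-dimensional "hidden states" are all assumptions made for illustration.

```python
import numpy as np

def knn_same_task_fraction(train_hidden, train_task_ids, query_hidden, query_task_id, k=5):
    """Fraction of the query's k nearest training neighbors that share its task identity.

    Assumption: neighbors are ranked by Euclidean distance in hidden space.
    """
    dists = np.linalg.norm(train_hidden - query_hidden, axis=1)
    nearest = np.argsort(dists)[:k]
    return float(np.mean(train_task_ids[nearest] == query_task_id))

def knn_predict_task(train_hidden, train_task_ids, query_hidden, k=5):
    """Majority-vote task identity among the k nearest training neighbors (assumed reading of panel b)."""
    dists = np.linalg.norm(train_hidden - query_hidden, axis=1)
    nearest = np.argsort(dists)[:k]
    labels, counts = np.unique(train_task_ids[nearest], return_counts=True)
    return int(labels[np.argmax(counts)])

# Toy stand-in for hidden states: two well-separated clusters, one per task.
rng = np.random.default_rng(0)
task0 = rng.normal(loc=0.0, scale=0.1, size=(10, 2))
task1 = rng.normal(loc=10.0, scale=0.1, size=(10, 2))
train_hidden = np.vstack([task0, task1])
train_task_ids = np.array([0] * 10 + [1] * 10)

# An "unseen instance" of task 0 that lands inside the task-0 cluster.
query = np.array([0.05, 0.0])
fraction = knn_same_task_fraction(train_hidden, train_task_ids, query, query_task_id=0, k=5)
predicted = knn_predict_task(train_hidden, train_task_ids, query, k=5)
```

In the paper's setting the vectors would presumably be hidden states taken from a given layer at a given training step, so the same measurement could be repeated per layer and per checkpoint.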