Complexity Control Facilitates Reasoning-Based Compositional Generalization in Transformers

Zhongwang Zhang, Pengxiao Lin, Zhiwei Wang, Yaoyu Zhang, Zhi-Qin John Xu

TL;DR

This study investigates the internal mechanisms underlying Transformers' behavior in compositional tasks and finds that complexity control strategies significantly influence whether the model learns primitive-level rules that generalize out-of-distribution (reasoning-based solutions) or relies solely on memorized mappings (memory-based solutions).

Abstract

Transformers have demonstrated impressive capabilities across various tasks, yet their performance on compositional problems remains a subject of debate. In this study, we investigate the internal mechanisms underlying Transformers' behavior in compositional tasks. We find that complexity control strategies significantly influence whether the model learns primitive-level rules that generalize out-of-distribution (reasoning-based solutions) or relies solely on memorized mappings (memory-based solutions). By applying masking strategies to the model's information circuits and employing multiple complexity metrics, we reveal distinct internal working mechanisms associated with different solution types. Further analysis reveals that reasoning-based solutions exhibit a lower complexity bias, which aligns with the well-studied neuron condensation phenomenon. This lower complexity bias is hypothesized to be the key factor enabling these solutions to learn reasoning rules. We validate these conclusions across multiple real-world datasets, including image generation and natural language processing tasks, confirming the broad applicability of our findings.
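The abstract mentions multiple complexity metrics and the neuron condensation phenomenon. The figure captions name two concrete quantities: the stable rank of the first-layer query matrix $W^{Q(1)}$ (Fig. 5C) and the cosine similarity between neurons' input weights (Fig. 5A). The NumPy sketch below computes both, assuming the standard stable-rank definition $\|W\|_F^2 / \|W\|_2^2$ and treating each row of a weight matrix as one neuron's input weights. The function names are illustrative and are not taken from the authors' code.

```python
import numpy as np


def stable_rank(W: np.ndarray) -> float:
    """Stable rank ||W||_F^2 / ||W||_2^2: the sum of squared singular values
    divided by the largest squared singular value (standard definition, assumed)."""
    s = np.linalg.svd(W, compute_uv=False)  # singular values in descending order
    return float(np.sum(s ** 2) / s[0] ** 2)


def neuron_cosine_matrix(W: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between neurons' input-weight vectors.

    Assumes each row of W holds one neuron's input weights, as in a weight of
    shape (out_features, in_features); pass W.T for the other convention.
    """
    norms = np.linalg.norm(W, axis=1, keepdims=True)
    unit = W / np.clip(norms, 1e-12, None)  # guard against all-zero rows
    return unit @ unit.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Fully "condensed" toy matrix: every row is a multiple of one direction.
    condensed = np.outer(rng.normal(size=16), rng.normal(size=32))
    random_w = rng.normal(size=(16, 32))
    print("stable rank (condensed, random):", stable_rank(condensed), stable_rank(random_w))
    print("min |cosine| (condensed):", np.abs(neuron_cosine_matrix(condensed)).min())
```

Under these definitions, a fully condensed matrix has stable rank 1 and a cosine matrix whose entries are all ±1, while a random Gaussian matrix scores higher on stable rank, so lower values correspond to the lower complexity bias described above. The same row-wise cosine function can be applied to stacked model-output vectors, as in the similarity matrices of Fig. 4D.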

Paper Structure

This paper contains 23 sections, 9 equations, and 10 figures.

Figures (10)

  • Figure 1: Schematic diagram of the dataset design. The dataset consists of key tokens (integers ranging from 20 to 100) combined with anchor pairs, which represent arithmetic operations. Each data point is formed by pairing a key token with an anchor pair, and the target is the result of applying the anchor pair's operations to the key token. The training set and ID test set use disjoint combinations of seen anchor pairs, while the OOD test set involves unseen anchor pairs formed from operations seen during training. The model is trained on the training set and evaluated on both the ID and OOD test sets to assess generalization.
  • Figure 2: Experimental setup for the compositional task. Left: The single anchors (namely $a$, $b$, $c$, $d$) correspond to specific arithmetic operations. Middle: 14 of the 16 possible anchor pairs appear in the training set, and the remaining pairs ($c$, $d$) and ($d$, $c$) are held out as unseen tasks that never appear during training. Right: Each input sequence comprises an anchor pair, a key token preceding the anchor pair, and noise tokens unrelated to the target. We construct mutually exclusive training and ID test sets from data generated with the 14 seen anchor pairs (the partitioning method is detailed in the paper's appendix on data splitting), while data from the remaining 2 unseen anchor pairs forms the OOD test set.
  • Figure 3: (A) ID and OOD generalization of the GPT-2 model on compositional tasks with a fixed weight decay coefficient of 0.01. The abscissa is the initialization rate $\gamma$, which sets the standard deviation $\left(1/{d_{\mathrm{in}}}\right)^{\gamma}$ of the normal distribution used for parameter initialization. The ordinate is the accuracy on ID (blue) and OOD (red) data. The phases are classified by their ID and OOD generalization abilities. (B) Heatmap of the GPT-2 model's OOD generalization on compositional tasks, with color intensity giving the accuracy on unseen anchor pairs. The abscissa matches that of panel (A), while the ordinate is the weight decay coefficient. Each setting shows the average over three independent trials. Striped regions indicate poor ID generalization ($\text{ID accuracy}<90\%$). Green triangles mark settings with poor commutativity on the unseen anchor pairs ($c$, $d$) and ($d$, $c$), i.e., a commutativity probability below 70% when the order of the two anchors is swapped.
  • Figure 4: (A, B) Masking strategies applied to the transformer model to analyze the contribution of specific tokens. (A) The key token ($x$) is masked, removing its information from the model's output and allowing analysis of how the anchor pairs ($a_1, a_2$) and noise tokens shape the output-space distribution. (B) The second anchor ($a_2$) is masked to isolate the interaction between the key token ($x$) and the first anchor ($a_1$). Dashed gray lines indicate masked circuits, while solid blue lines represent normal circuits. Hollow nodes depict masked activations, and filled nodes depict normal activations. (C) Principal Component Analysis (PCA) applied to the model outputs to visualize the model's representation of different anchor pairs. For each anchor pair, 50 data points were sampled. Symmetric anchor pairs (e.g., $(c, d)$ and $(d, c)$) are shown in similar colors with different shades to indicate their equivalence. The three phases are obtained by adjusting the initialization rate ($\gamma$). (D) Cosine similarity matrices between model outputs after masking the second anchor ($a_2$). Each group of four blocks corresponds to inputs with the same $g(x; a_1)$ value. The right panel provides a magnified view of selected blocks, showing the values of $x$ and $a_1$ for the corresponding inputs. The color scale represents cosine similarity values. Different phases correspond to different values of the initialization rate $\gamma$.
  • Figure 5: (A) Cosine similarity between the input weights of neurons in the first layer's query weight matrix ($W^{Q(1)}$) for each phase. Both axes index the neurons. The matrices are computed with the weight decay coefficient fixed at 0.01 and the initialization rate ($\gamma$) set to 0.2, 0.5, and 0.8 for Phase 1, Phase 2, and Phase 3, respectively. (B) PCA visualization of the word embedding vectors for the same settings as in (A). Each number corresponds to a specific token, and its position represents the reduced-dimensional embedding obtained through PCA. (C) Bubble plot summarizing the stable rank of the parameter matrix $W^{Q(1)}$ across initialization rates ($\gamma$) and weight decay coefficients. The size of each bubble represents the phase index (w: with commutativity; w/o: without commutativity); Phase 3 and Phase 2 are distinguished by whether the OOD accuracy exceeds 50%, while the other phase boundaries remain consistent with Fig. 3. The color of each bubble represents the stable rank value, and all results are averaged over three independent trials.
  • ...and 5 more figures
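Figures 1 and 2 describe the data-generating process in enough detail to sketch it. The Python sketch below builds the 16 ordered anchor pairs from four single anchors, holds out ($c$, $d$) and ($d$, $c$) for the OOD test set, and splits the remaining (key, anchor pair) combinations into disjoint training and ID test sets. The shift assigned to each anchor, the sequence length, the noise vocabulary, and the random 80/20 split are all hypothetical choices: the captions do not give them, and the actual partitioning method is described in the paper's appendix.

```python
import itertools
import random

# Hypothetical shift for each single anchor; the captions do not give the real operations.
ANCHOR_OPS = {"a": 5, "b": 1, "c": -2, "d": -8}
KEYS = range(20, 101)                 # key tokens 20..100 (inclusive upper bound assumed)
HELD_OUT = {("c", "d"), ("d", "c")}   # unseen anchor pairs, used only for the OOD test set
SEQ_LEN = 8                           # hypothetical sequence length


def g(x, anchor):
    """Apply one anchor's (hypothetical) arithmetic operation to x."""
    return x + ANCHOR_OPS[anchor]


def make_example(x, pair, rng):
    """Place (key, a1, a2) at a random slot among noise tokens; target = g(g(x, a1), a2)."""
    a1, a2 = pair
    tokens = [rng.choice(KEYS) for _ in range(SEQ_LEN)]  # noise drawn from the key range (assumed)
    pos = rng.randrange(SEQ_LEN - 2)                      # slot where the key token starts
    tokens[pos:pos + 3] = [x, a1, a2]                     # key token precedes the anchor pair
    return {"tokens": tokens, "key": x, "pair": pair, "target": g(g(x, a1), a2)}


def build_splits(id_test_frac=0.2, seed=0):
    """Return (train, id_test, ood_test) lists of example dicts."""
    rng = random.Random(seed)
    pairs = list(itertools.product(ANCHOR_OPS, repeat=2))  # 16 ordered anchor pairs
    seen = [p for p in pairs if p not in HELD_OUT]          # 14 seen pairs
    combos = [(x, p) for p in seen for x in KEYS]
    rng.shuffle(combos)                                     # hypothetical random split
    n_test = int(len(combos) * id_test_frac)
    id_test = [make_example(x, p, rng) for x, p in combos[:n_test]]
    train = [make_example(x, p, rng) for x, p in combos[n_test:]]
    ood_test = [make_example(x, p, rng) for p in sorted(HELD_OUT) for x in KEYS]
    return train, id_test, ood_test
```

Calling `train, id_test, ood_test = build_splits()` yields disjoint (key, pair) combinations for training and ID testing, and an OOD set built only from the two held-out pairs. Because the hypothetical operations here are additive shifts, ($c$, $d$) and ($d$, $c$) map every key to the same target in this sketch, which matches the equivalence of symmetric pairs noted in Fig. 4C.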
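Figure 3 parameterizes initialization by the rate $\gamma$: parameters are drawn from a normal distribution with standard deviation $\left(1/d_{\mathrm{in}}\right)^{\gamma}$, so a larger $\gamma$ gives a smaller initialization scale. A minimal NumPy sketch of this rule for a single weight matrix follows. Taking $d_{\mathrm{in}}$ to be the matrix's fan-in and using a width of 256 are assumptions for illustration only.

```python
import numpy as np


def init_std(d_in: int, gamma: float) -> float:
    """Standard deviation (1/d_in)**gamma for initialization rate gamma."""
    return (1.0 / d_in) ** gamma


def init_weight(d_in: int, d_out: int, gamma: float, seed: int = 0) -> np.ndarray:
    """Draw a (d_in, d_out) weight matrix from N(0, init_std(d_in, gamma)**2)."""
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, init_std(d_in, gamma), size=(d_in, d_out))


if __name__ == "__main__":
    d_in = 256  # hypothetical fan-in
    for gamma in (0.2, 0.5, 0.8):  # the Phase 1/2/3 settings quoted for Fig. 5
        W = init_weight(d_in, d_in, gamma)
        print(f"gamma={gamma}: target std={init_std(d_in, gamma):.4f}, sample std={W.std():.4f}")
```

The weight decay coefficient, the other axis of Fig. 3B, is set in the optimizer rather than at initialization, so the two complexity-control knobs act at different stages of training.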
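Figure 4 analyzes the model by masking the circuits that carry a given token's information, such as the key token $x$ or the second anchor $a_2$. The captions do not say how the mask is implemented, so the NumPy sketch below swaps in one common approach as an assumption: in single-head causal self-attention, no position may attend to a blocked position, although each position still attends to itself so that every softmax row stays well defined. The function name is hypothetical, and in a multi-layer model the same mask would have to be applied in every layer.

```python
import numpy as np


def masked_causal_attention(x, w_q, w_k, w_v, blocked=()):
    """Single-head causal self-attention in which no position may read from the
    positions listed in `blocked` (apart from each blocked position itself)."""
    seq_len = x.shape[0]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[1])
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # hide future tokens
    for j in blocked:
        mask[:, j] = True            # no position reads from position j ...
    np.fill_diagonal(mask, False)    # ... except j itself, so no row is entirely masked
    scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)                     # masked entries become exactly 0
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ v
```

With `blocked=(2,)`, changing the embedding at position 2 leaves the outputs of every other position unchanged in this single layer, which is the kind of information cut the masking analysis relies on.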