Attention as a Hypernetwork

Simon Schug, Seijin Kobayashi, Yassir Akram, João Sacramento, Razvan Pascanu

TL;DR

This work reframes multi-head attention as a hypernetwork: the attention scores across heads form a low-dimensional latent code that configures a value network specific to each key-query pair. It shows that scaling model size and data facilitates compositional generalization on abstract reasoning tasks and yields a structured latent space that is predictive of the network's function. By introducing HYLA, a variant that adds a nonlinear value network and head-wise normalization, the authors demonstrate improved compositional generalization on challenging tasks, including sraven, a symbolic Raven-like benchmark. The results suggest that the hypernetwork mechanism in attention underpins substantial aspects of in-context learning and compositionality, with practical implications for understanding and improving large-scale transformer models.
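The reframing can be made concrete with a short derivation. The notation here is our own assumption for illustration and may differ from the paper's: for heads $h = 1, \dots, H$ with query, key, value and output projections $W^{q}_h, W^{k}_h, W^{v}_h, W^{o}_h$ and key dimension $d_k$, standard softmax multi-head attention computes the output for query token $i$ as

```latex
\begin{aligned}
a_{h,ij} &= \operatorname{softmax}_j\!\left(\frac{(W^{q}_h x_i)^\top (W^{k}_h x_j)}{\sqrt{d_k}}\right),\\
y_i &= \sum_{h=1}^{H} W^{o}_h \sum_{j} a_{h,ij}\, W^{v}_h x_j
     = \sum_{j} \underbrace{\Bigl(\sum_{h=1}^{H} a_{h,ij}\, W^{o}_h W^{v}_h\Bigr)}_{W_{ij}} x_j .
\end{aligned}
```

Exchanging the sums over heads and keys shows that the vector $z_{ij} = (a_{1,ij}, \dots, a_{H,ij}) \in \mathbb{R}^H$ of attention scores along the head index acts as a latent code: a linear hypernetwork with fixed per-head matrices $W^{o}_h W^{v}_h$ maps it to the weights $W_{ij}$ of a linear value network applied to $x_j$, separately for every key-query pair $(i, j)$. Because $H$ is typically far smaller than the number of entries of $W_{ij}$, this code is low-dimensional.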

Abstract

Transformers can, under some circumstances, generalize to novel problem instances whose constituent parts may have been encountered during training but whose compositions have not. What mechanisms underlie this ability for compositional generalization? By reformulating multi-head attention as a hypernetwork, we reveal that a composable, low-dimensional latent code specifies key-query specific operations. We find empirically that this latent code is predictive of the subtasks the network performs on unseen task compositions, revealing that latent codes acquired during training are reused to solve unseen problem instances. To further test the hypothesis that the intrinsic hypernetwork of multi-head attention supports compositional generalization, we examine whether making the hypernetwork-generated linear value network nonlinear strengthens compositionality. We find that this modification improves compositional generalization on abstract reasoning tasks. In particular, we introduce a symbolic version of the Raven's Progressive Matrices human intelligence test, which gives us precise control over the problem compositions encountered during training and evaluation. We demonstrate on this task how scaling model size and data enables compositional generalization in transformers and gives rise to a functionally structured latent space.
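The reformulation described in the abstract can be checked numerically. The following is a minimal NumPy sketch, assuming standard softmax multi-head attention with per-head projections `Wq`, `Wk`, `Wv`, `Wo`; all function names and shapes are our own illustrative choices, not taken from the paper's code. `mha_as_hypernetwork` builds, for every key-query pair (i, j), a weight matrix from the latent code `a[:, i, j]` (the attention scores across heads) and applies it to x_j; it should match `mha_standard` up to floating-point error. `nonlinear_value_variant` is a hypothetical illustration of making the generated value network nonlinear by inserting `phi` between the generated value and output layers; it omits the head-wise normalization mentioned for HYLA and should not be read as the paper's exact definition.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_scores(X, Wq, Wk):
    # X: (T, d); Wq, Wk: (H, d_k, d). Returns a: (H, T, T), softmax over keys j.
    Q = np.einsum("hkd,td->htk", Wq, X)
    K = np.einsum("hkd,td->htk", Wk, X)
    logits = np.einsum("hik,hjk->hij", Q, K) / np.sqrt(Wq.shape[1])
    return softmax(logits, axis=-1)


def mha_standard(X, Wq, Wk, Wv, Wo):
    # Usual view: per-head weighted sum of values, then per-head output projection.
    # Wv: (H, d_v, d); Wo: (H, d, d_v).
    a = attention_scores(X, Wq, Wk)
    V = np.einsum("hvd,td->htv", Wv, X)          # (H, T, d_v)
    heads = np.einsum("hij,hjv->hiv", a, V)      # (H, T, d_v)
    return np.einsum("hdv,hiv->id", Wo, heads)   # (T, d), summed over heads


def mha_as_hypernetwork(X, Wq, Wk, Wv, Wo):
    # Hypernetwork view: the latent code z_ij = a[:, i, j] in R^H linearly
    # generates W_ij = sum_h z_ij[h] * (Wo[h] @ Wv[h]), and y_i = sum_j W_ij x_j.
    a = attention_scores(X, Wq, Wk)
    basis = np.einsum("hdv,hve->hde", Wo, Wv)    # (H, d, d): one matrix per head
    W = np.einsum("hij,hde->ijde", a, basis)     # (T, T, d, d): one W_ij per pair
    return np.einsum("ijde,je->id", W, X)


def relu(u):
    return np.maximum(u, 0.0)


def nonlinear_value_variant(X, Wq, Wk, Wv, Wo, phi=relu):
    # Hypothetical sketch: the same latent code generates both layers of a
    # two-layer value network, W1_ij = sum_h z_ij[h] Wv[h] and
    # W2_ij = sum_h z_ij[h] Wo[h], with phi applied in between.
    a = attention_scores(X, Wq, Wk)
    W1 = np.einsum("hij,hve->ijve", a, Wv)          # (T, T, d_v, d)
    W2 = np.einsum("hij,hdv->ijdv", a, Wo)          # (T, T, d, d_v)
    hidden = phi(np.einsum("ijve,je->ijv", W1, X))  # (T, T, d_v)
    return np.einsum("ijdv,ijv->id", W2, hidden)    # (T, d)


rng = np.random.default_rng(0)
T, d, d_k, d_v, H = 5, 8, 4, 4, 3
X = rng.normal(size=(T, d))
Wq, Wk = rng.normal(size=(H, d_k, d)), rng.normal(size=(H, d_k, d))
Wv, Wo = rng.normal(size=(H, d_v, d)), rng.normal(size=(H, d, d_v))

y_std = mha_standard(X, Wq, Wk, Wv, Wo)
y_hyp = mha_as_hypernetwork(X, Wq, Wk, Wv, Wo)
print(np.allclose(y_std, y_hyp))  # True: the two views are algebraically identical
```

Note that the two-layer variant is not a re-parameterization of the linear case even when `phi` is the identity: multiplying the two generated layers introduces cross-head products $a_{h,ij} a_{h',ij}$, so the latent code enters quadratically rather than linearly.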

Paper Structure (41 sections, 8 equations, 15 figures, 3 tables, 2 algorithms)

Figures (15)

  • Figure 1: Hypernetwork attention. A A linear hypernetwork maps a latent code to a set of parameters that configure a value network to process the input. B The attention scores along the head index form the latent code of the hypernetwork. C Multi-head attention can be equivalently expressed as a linear hypernetwork that configures key-query specific computations of a linear value network.
  • Figure 2: Compositional generalization on fuzzy logic functions. A We split fuzzy logic functions according to their constituent terms into train and out-of-distribution (OOD) sets to measure compositional generalization in a sequence model that learns these functions in-context. B The latent code of the response token is predictive of the constituent terms underlying each task. Shown is the F1 score on unseen tasks of logistic regression classifiers, one per layer and term, trained to predict the terms underlying each task from the latent code (the attention scores across the head index for the response token attending to itself). C Task performance on unseen tasks, reported as OOD $R^2$, for varying numbers of in-context examples and fractions of tasks held out during training. D t-SNE visualization of the latent codes (attention scores across the head index for the response token attending to itself), colored by the target label (top) and by the first term of the fuzzy logic function of each task (bottom).
  • Figure 3: sraven. A Illustration of sraven task generation and the construction of the compositional generalization split. B Example problem instance illustrating a key challenge of the original Raven's Progressive Matrices: finding correspondences (adapted from Carpenter et al., 1990). When attempting to solve this instance, several hypotheses are possible about which figural elements are governed by a consistent rule across rows. This is akin to different orderings of the symbolic features.
  • Figure 4: Scaling data and model size on sraven. A Compositional generalization measured by OOD accuracy as a function of the number of training problem instances for different widths (scaling embedding and key-query-value dimensions). B Same as A but increasing depth instead of width.
  • Figure 5: Latent code structure of sraven. A t-SNE visualizations of the final-layer latent codes for the final response token, colored by the magnitude of the predicted target value. B Same as A but colored by the ground-truth rule the model needs to apply to generate the correct prediction. C OOD accuracy for varying numbers of attention heads. With a single head, the hypernetwork mechanism is absent, which hampers OOD generalization. D The difficulty of sraven can be controlled parametrically by varying $K$, the number of features per panel. E Heatmap of the pairwise cosine similarity between the average latent codes of each rule for HYLA, revealing that semantically related rules form clusters. For instance, rules F (addition) and G (difference) are implemented with very similar latent codes, suggesting that the same code might be reused by flipping the sign of the operands. F Decoding performance of a logistic regression classifier trained to predict the ground-truth sraven rule from the latent code at the final response token of training tasks and evaluated on unseen OOD tasks, revealing that the latent code is predictive of the implemented rule.
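The latent-code analyses in the captions (decoding rules or terms from the latent code in Figures 2B and 5F, and the heatmap of cosine similarities between rule-averaged codes in Figure 5E) can be sketched as follows. This is a minimal illustration on synthetic data, not the paper's pipeline or results: all function names are hypothetical, and for a dependency-free example a nearest-centroid decoder is swapped in for the logistic regression classifier the captions describe. In the paper's setting, each row of `codes` would be a latent code, i.e. the attention scores across the head index at the final response token, and the evaluation split would consist of held-out rule compositions rather than a fresh sample.

```python
import numpy as np


def rule_centroids(codes, labels, n_rules):
    # Average latent code per rule: codes (N, H), labels (N,) ints in [0, n_rules).
    return np.stack([codes[labels == r].mean(axis=0) for r in range(n_rules)])


def cosine_similarity_matrix(M):
    # Pairwise cosine similarity between the rows of M (cf. the Figure 5E heatmap).
    U = M / np.linalg.norm(M, axis=1, keepdims=True)
    return U @ U.T


def nearest_centroid_accuracy(centroids, codes, labels):
    # Stand-in decoder: assign each code to the rule with the closest training centroid.
    dists = np.linalg.norm(codes[:, None, :] - centroids[None, :, :], axis=-1)
    return float(np.mean(dists.argmin(axis=1) == labels))


def sample_codes(rng, centers, n_per_rule, noise=0.05):
    # Synthetic latent codes: one Gaussian cluster per rule (NOT data from the paper).
    labels = np.repeat(np.arange(len(centers)), n_per_rule)
    codes = centers[labels] + noise * rng.normal(size=(labels.size, centers.shape[1]))
    return codes, labels


rng = np.random.default_rng(0)
H, R = 8, 4                                   # number of heads, number of rules
centers = rng.normal(size=(R, H))
train_codes, train_labels = sample_codes(rng, centers, 50)
eval_codes, eval_labels = sample_codes(rng, centers, 50)  # stands in for OOD tasks

C = rule_centroids(train_codes, train_labels, R)
S = cosine_similarity_matrix(C)               # (R, R) rule-by-rule similarity
acc = nearest_centroid_accuracy(C, eval_codes, eval_labels)
print(S.round(2), acc)
```

Swapping in a linear classifier such as logistic regression would follow the captions more closely; the nearest-centroid rule is used here only to keep the example free of dependencies beyond NumPy.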