QUEST: A robust attention formulation using query-modulated spherical attention

Hariprasath Govindarajan, Per Sidén, Jacob Roll, Fredrik Lindsten

Abstract

The Transformer model architecture has become one of the most widely used in deep learning and the attention mechanism is at its core. The standard attention formulation uses a softmax operation applied to a scaled dot product between query and key vectors. We explore the role played by norms of the queries and keys, which can cause training instabilities when they arbitrarily increase. We demonstrate how this can happen even in simple Transformer models, in the presence of easy-to-learn spurious patterns in the data. We propose a new attention formulation, QUEry-modulated Spherical aTtention (QUEST), that constrains the keys to a hyperspherical latent space, while still allowing individual tokens to flexibly control the sharpness of the attention distribution. QUEST can be easily used as a drop-in replacement for standard attention. We focus on vision applications while also exploring other domains to highlight the method's generality. We show that (1) QUEST trains without instabilities and (2) produces models with improved performance (3) that are robust to data corruptions and adversarial attacks.
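The abstract describes QUEST as constraining keys to a hyperspherical latent space while letting individual tokens control the sharpness of the attention distribution. The exact formulation is not given in this excerpt, so the following is only a minimal NumPy sketch under two assumptions: keys are projected to the unit hypersphere by L2 normalization, and the unnormalized query norm is retained so that it acts as a per-token temperature on the softmax. Function names and the single-head, unbatched setup are illustrative, not the paper's implementation.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V):
    # Standard scaled dot-product attention: softmax(QK^T / sqrt(d)) V.
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

def quest_attention_sketch(Q, K, V):
    # Hypothetical sketch of query-modulated spherical attention:
    # keys are constrained to the unit hypersphere, so their norms
    # cannot grow arbitrarily, while the query norm survives as a
    # per-token control on the sharpness of the attention weights.
    K_sphere = K / np.linalg.norm(K, axis=-1, keepdims=True)
    d = Q.shape[-1]
    scores = Q @ K_sphere.T / np.sqrt(d)  # ||q|| modulates the temperature
    return softmax(scores) @ V
```

Because the softmax rows are convex weights, both functions return outputs in the convex hull of the value vectors; the sketch only changes how the logits are formed, so it can be swapped in wherever standard attention is used.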

Figures (13)

  • Figure 1: Class-activation maps for an image from the Macaw class in ImageNet, generated using AG-CAM (agcam). Standard attention concentrates on a few bird instances (first row) and misclassifies the image when the region containing those instances is noised (third row), indicating that the birds in the bottom half of the image do not contribute to the correct prediction. When the top part of the image is noised, the model shifts its focus to the birds in the bottom part, as they are then the most salient objects, but still misclassifies the image. QUEST attention attends evenly to the different bird instances and classifies the image correctly even when some of them are noised. This more diverse attention can make models more robust to input variations, as reflected in the improved robustness results in \ref{sec:expt_image_robustness}.
  • Figure 2: Illustration of toy example.
  • Figure 3: Success rates of learning the correct solution to the toy example. Models are trained with different hyperparameter combinations, using 5 different weight initializations and 5 different realizations of the data. The QKNorm methods obtained an overall success rate of $\sim$0%; their results are shown in Figure \ref{fig:toy_example_simple_test_success_rates_2x3}.
  • Figure 4: Norms of answer key tokens for biased and unbiased samples. A common failure case for standard and QNorm attention involves the key norm of the biased answer token increasing as training progresses: the model relies on looking up the bias vector to identify the answer.
  • Figure 5: Class activation maps for Elephant and Zebra. Model with QUEST attention shows better coverage of the different instances of the animals than standard attention.
  • ...and 8 more figures