Towards understanding how attention mechanism works in deep learning

Tianyu Ruan, Shihua Zhang

TL;DR

This study inspects the process of computing similarity using classic metrics and vector space properties in manifold learning, clustering, and supervised learning, and demonstrates that the self-attention mechanism in deep learning adheres to the same principles but operates more flexibly and adaptively.

Abstract

The attention mechanism has been widely integrated into mainstream neural network architectures, such as Transformers and graph attention networks. Yet its underlying working principles remain somewhat elusive. What is its essence? Is it connected to traditional machine learning algorithms? In this study, we inspect how similarity is computed from classic metrics and vector space properties in manifold learning, clustering, and supervised learning. We identify the key characteristics of similarity computation and information propagation in these methods and demonstrate that the self-attention mechanism in deep learning adheres to the same principles but operates more flexibly and adaptively. We decompose the self-attention mechanism into a learnable pseudo-metric function and an information propagation process based on similarity computation. Through continuous modeling, we prove that the self-attention mechanism converges to a drift-diffusion process, provided that the pseudo-metric is a transformation of a metric and certain reasonable assumptions hold. This drift-diffusion equation can be transformed into a heat equation under a new metric. In addition, we give a first-order analysis of the attention mechanism with a general pseudo-metric function. This study helps explain the effects and principles of the attention mechanism through physical intuition. Finally, we propose a modified attention mechanism, called metric-attention, which leverages the concept of metric learning to learn desired metrics more effectively. Experimental results demonstrate that it outperforms self-attention in training efficiency, accuracy, and robustness.
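The decomposition described in the abstract (a learnable similarity, followed by exponential inflation and row normalization, followed by propagation) can be sketched for standard single-head self-attention as below. This is a minimal NumPy illustration, not the authors' code: the function names are hypothetical, the usual $1/\sqrt{d}$ scaling is omitted (it can be absorbed into `W_q`), and the squared-distance rewriting is a standard algebraic identity used here to show how a dot-product score hides a metric term, not necessarily the paper's own derivation. Metric-attention is not implemented here.

```python
import numpy as np

def row_softmax(S):
    """Exponential inflation followed by row normalization."""
    E = np.exp(S - S.max(axis=1, keepdims=True))  # row shift cancels after normalizing
    return E / E.sum(axis=1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention as similarity -> weights -> propagation."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    S = Q @ K.T                 # learnable similarity (QK dot product)
    P = row_softmax(S)          # row i: how token i distributes attention over tokens j
    return P @ V, P             # propagate values using those weights

def distance_form_weights(X, W_q, W_k):
    """The same weights written with a squared distance plus a per-key term."""
    Q, K = X @ W_q, X @ W_k
    sq_dist = ((Q[:, None, :] - K[None, :, :]) ** 2).sum(axis=-1)  # |q_i - k_j|^2
    key_term = 0.5 * (K ** 2).sum(axis=1)                           # |k_j|^2 / 2
    # q.k = -|q - k|^2/2 + |q|^2/2 + |k|^2/2; the |q_i|^2/2 term is constant
    # along row i and cancels in the row softmax, so it is dropped.
    return row_softmax(-0.5 * sq_dist + key_term[None, :])

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))                       # 6 tokens with 4 features each
W_q, W_k, W_v = (rng.normal(size=(4, 3)) for _ in range(3))
out, P = self_attention(X, W_q, W_k, W_v)
print(np.allclose(P, distance_form_weights(X, W_q, W_k)))  # True
```

Since $q_i \cdot k_j = -\tfrac12\|q_i-k_j\|^2 + \tfrac12\|q_i\|^2 + \tfrac12\|k_j\|^2$, the softmax weights equal a Gaussian kernel in the learned query/key space, reweighted by a per-key factor $e^{\|k_j\|^2/2}$.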

Paper Structure

This paper contains 47 sections, 9 theorems, 93 equations, 6 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 1

Suppose $\{x_i\}_{i=1}^N$ are sampled i.i.d. from the uniform measure on a compact Riemannian manifold, then where $f(\cdot)$ is a smooth function, $\Delta$ is the (negative semi-definite) Laplace-Beltrami operator of the Riemannian manifold, and $\mathcal{O}$ denotes big-O notation.
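Results of the kind stated in Theorem 1 relate a kernel-weighted average over points on a manifold to the Laplace-Beltrami operator. The following is a hedged numerical illustration, not the paper's exact statement. It builds the row-normalized Gaussian kernel $P_{ij} \propto \exp(-\|x_i-x_j\|^2/(4t))$, which is a softmax attention matrix whose scores are negative scaled squared distances. It then checks that $(Pf - f)/t \approx \Delta f$ for $f=\cos\theta$ on the unit circle. The bandwidth $4t$, the $1/t$ scaling, and the use of an equispaced grid in place of i.i.d. samples (to suppress Monte Carlo noise) are our assumptions.

```python
import numpy as np

def kernel_laplacian(points, f_vals, t):
    """(P f - f) / t with P the row-softmax of -|x_i - x_j|^2 / (4t)."""
    sq_dist = ((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-sq_dist / (4 * t))   # Gaussian heat kernel; the diagonal is 1
    P = K / K.sum(axis=1, keepdims=True)
    return (P @ f_vals - f_vals) / t

# Equispaced points on the unit circle, a compact 1-D Riemannian manifold in R^2.
N = 1000
theta = 2 * np.pi * np.arange(N) / N
points = np.stack([np.cos(theta), np.sin(theta)], axis=1)
f_vals = np.cos(theta)
exact = -np.cos(theta)               # Laplace-Beltrami of cos(theta): d^2/dtheta^2 cos = -cos

errors = {t: float(np.max(np.abs(kernel_laplacian(points, f_vals, t) - exact)))
          for t in (1e-2, 1e-3)}
print(errors)
```

The error should shrink roughly in proportion to $t$. With genuinely i.i.d. samples, $N$ would also have to grow quickly as $t \to 0$ to keep the sampling variance under control.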

Figures (6)

  • Figure 1: Illustration of three architectural components of a Transformer block: the residual block (A), the attention block (C), and their recombination (B).
  • Figure 2: Illustration of the main idea. (a) The attention mechanism consists of two main steps: (1) computing the similarity between nodes (data points or tokens) and then propagating node features to neighboring nodes, weighted by these similarities, and (2) updating the features of the nodes. (b) Illustration of a drift-diffusion process on the manifold where the data reside. This process is driven by two main forces: density guidance, which encourages local concentration, and diffusion, which promotes global consistency of features. This study demonstrates that the attention mechanism can be regarded as a first-order approximation of the drift-diffusion process on the manifold, i.e., short-time diffusion.
  • Figure 3: Experimental results on the MNIST dataset. The left subplot shows the loss curve during training. The middle subplot shows the testing accuracy during training. The right subplot is a violin plot of the test accuracy of the three structures at the end of training.
  • Figure 4: Experimental results on the Moon dataset. (top-left) Visualization of this dataset. (top-right) the testing accuracy of the three methods at the end of training. (bottom-left and bottom-right) illustration of the testing accuracy and loss of the three methods during the training process, respectively.
  • Figure 5: Evaluation of self-attention, L2 self-attention and metric-attention methods on the Human semantic segmentation dataset. (top-left) Visualization of four instances in the dataset. (top-right) the testing accuracy of the three methods at the end of training. (bottom-left and bottom-right) illustration of the testing accuracy and loss of the three methods during the training process, respectively.
  • ...and 1 more figure
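Figure 2 describes attention as short-time diffusion that promotes global consistency of features. Below is a minimal sketch of that smoothing effect, under simplifying assumptions of our own. First, the attention matrix is a fixed distance-based softmax; in real self-attention it is recomputed from the features at every layer. Second, each layer applies the residual update $x \leftarrow x + \alpha(Px - x)$ with $0 \le \alpha \le 1$. Under these assumptions, every updated token is a convex combination of the previous tokens, so the spread of each feature across tokens cannot increase. The density-driven drift term is not isolated here.

```python
import numpy as np

def distance_attention(X, t):
    """Row-softmax of -|x_i - x_j|^2 / (4t): an attention matrix built from distances."""
    sq_dist = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-sq_dist / (4 * t))
    return K / K.sum(axis=1, keepdims=True)

def feature_spread(F):
    """Largest max-minus-min over tokens, taken across feature columns."""
    return float((F.max(axis=0) - F.min(axis=0)).max())

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))     # token positions, used only to build the similarities
F = rng.normal(size=(50, 3))     # token features that get propagated
P = distance_attention(X, t=0.5)
alpha = 0.5                      # residual step size, assumed to lie in [0, 1]

spreads = [feature_spread(F)]
for _ in range(20):
    F = F + alpha * (P @ F - F)  # = (1 - alpha) F + alpha P F, a convex combination per token
    spreads.append(feature_spread(F))
print([round(s, 3) for s in spreads[::5]])
```

Repeating the step drives all tokens toward a common feature vector, which is the "global consistency" side of the picture. Stacking many such layers without other components is one way to see why over-smoothing can occur.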

Theorems & Definitions (22)

  • Definition 1: Metric-based similarity generation
  • Definition 2: $QK$-dot product
  • Definition 3: Local combination similarity
  • Definition 4: $k$-th order adjacency similarity
  • Definition 5: $r$-Inflation
  • Definition 6: Exponential Inflation
  • Definition 7: Row normalization
  • Definition 8: Column normalization
  • Definition 9: Two-side normalization
  • Definition 10: Global normalization
  • ...and 12 more
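Definitions 5 to 10 are listed here only by name. The sketch below uses the standard forms these operations take in spectral clustering and Markov clustering. These forms are assumptions, and the paper's exact definitions (for example, scaling constants) may differ. Under these forms, softmax attention is exponential inflation followed by row normalization.

```python
import numpy as np

def r_inflation(A, r):
    """Elementwise power A_ij ** r; assumes A >= 0. r > 1 sharpens, r < 1 flattens."""
    return A ** r

def exp_inflation(S, tau=1.0):
    """Elementwise exp(S_ij / tau): the inflation used by softmax attention."""
    return np.exp(S / tau)

def row_normalize(A):
    """D_row^{-1} A: every row sums to 1 (row-stochastic)."""
    return A / A.sum(axis=1, keepdims=True)

def column_normalize(A):
    """A D_col^{-1}: every column sums to 1."""
    return A / A.sum(axis=0, keepdims=True)

def two_side_normalize(A):
    """D^{-1/2} A D^{-1/2} with D the row-sum (degree) matrix; keeps a symmetric A symmetric."""
    inv_sqrt_deg = 1.0 / np.sqrt(A.sum(axis=1))
    return inv_sqrt_deg[:, None] * A * inv_sqrt_deg[None, :]

def global_normalize(A):
    """A / sum(A): all entries together sum to 1."""
    return A / A.sum()

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
S = X @ X.T                                  # symmetric dot-product similarity
A = exp_inflation(S)                         # strictly positive and symmetric
attention = row_normalize(A)                 # softmax attention weights (Q = K = X)
sharper = row_normalize(r_inflation(A, 2.0))
sym = two_side_normalize(A)
```

For $r > 1$, $r$-inflation followed by row normalization concentrates each row's weight on its largest entries, much like lowering the softmax temperature. The two-side form keeps the matrix symmetric, with largest eigenvalue 1, as in spectral clustering.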