Are queries and keys always relevant? A case study on Transformer wave functions

Riccardo Rende, Luciano Loris Viteritti

TL;DR

This work investigates whether the canonical queries-keys mechanism of Transformer attention is essential when parametrizing quantum many-body ground states with a Vision Transformer–based neural network state. Through variational Monte Carlo with Stochastic Reconfiguration, the authors compare standard attention (T5/Decoupled) to a Factored, input-independent attention in the $2$D $J_1$-$J_2$ Heisenberg model, finding essentially identical accuracy but reduced cost for the latter. They show that attention weights become input-independent at convergence and provide analytical arguments, including an exact mapping for the Shastry-Sutherland ground state, explaining why large systems favor positional-only connections. The results suggest that, in systems with decaying correlations, queries and keys may be unnecessary, which has practical implications for scaling Transformer-based quantum states and potentially informs attention design in NLP and vision tasks with long sequences.

Abstract

The dot product attention mechanism, originally designed for natural language processing tasks, is a cornerstone of modern Transformers. It adeptly captures semantic relationships between word pairs in sentences by computing a similarity overlap between queries and keys. In this work, we explore the suitability of Transformers, focusing on their attention mechanisms, in the specific domain of the parametrization of variational wave functions to approximate ground states of quantum many-body spin Hamiltonians. Specifically, we perform numerical simulations on the two-dimensional $J_1$-$J_2$ Heisenberg model, a common benchmark in the field of quantum many-body systems on a lattice. By comparing the performance of standard attention mechanisms with a simplified version that excludes queries and keys, relying solely on positions, we achieve competitive results while reducing computational cost and parameter usage. Furthermore, through the analysis of the attention maps generated by standard attention mechanisms, we show that the attention weights become effectively input-independent at the end of the optimization. We support the numerical results with analytical calculations, providing physical insight into why queries and keys should be, in principle, omitted from the attention mechanism when studying large systems.
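The contrast drawn in the abstract, between attention weights computed from queries and keys and weights that depend only on positions, can be sketched in a few lines of NumPy. This is a minimal single-head illustration under stated assumptions, not the authors' implementation: the function names, the 1D periodic relative index `(i - j) mod n`, and the absence of a softmax in the positional variant are all choices made here for brevity (the paper works with patches of a 2D lattice).

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(x, Wq, Wk, Wv):
    """Standard single-head attention: the weights depend on the input x."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    weights = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # shape (n, n)
    return weights @ v, weights

def positional_attention(x, alpha, Wv):
    """Positional-only attention (hypothetical sketch): no queries or keys.

    weights[i, j] = alpha[(i - j) mod n], so the weights are identical
    for every input x. The 1D periodic index is an assumption.
    """
    n = x.shape[0]
    idx = (np.arange(n)[:, None] - np.arange(n)[None, :]) % n
    weights = alpha[idx]  # shape (n, n), input-independent
    return weights @ (x @ Wv), weights

rng = np.random.default_rng(0)
n, d = 9, 6  # illustrative sizes: 9 input vectors of dimension 6
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
alpha = rng.normal(size=n)

x1, x2 = rng.normal(size=(n, d)), rng.normal(size=(n, d))
out_std, w_std_1 = dot_product_attention(x1, Wq, Wk, Wv)
_, w_std_2 = dot_product_attention(x2, Wq, Wk, Wv)
out_pos, w_pos_1 = positional_attention(x1, alpha, Wv)
_, w_pos_2 = positional_attention(x2, alpha, Wv)

# Parameter count per head in this sketch: 3*d*d versus d*d + n.
n_params_std = 3 * d * d
n_params_pos = d * d + n
```

In this sketch the positional variant drops the two `d x d` query and key matrices and the `n x n` product `q @ k.T`, which is where the savings in parameters and cost come from.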

Paper Structure (14 sections, 18 equations, 5 figures, 1 table, 1 algorithm)


Figures (5)

  • Figure 1: Schematic representation of the attention mechanisms employed in this work: T5 \cite{raffel2023exploring} (left panel), Decoupled \cite{dai2021coatnet} (central panel) and Factored \cite{bhattacharya2020single, rende2024mapping} (right panel) attention. In each of them, relative positional encoding is used. The matrices $Q$, $K$, $V$ and $P$ are referred to as the queries, keys, values and positional encoding matrix, respectively. Refer to Eqs. \ref{eq:t5}, \ref{eq:decoupled} and \ref{eq:factored} in the main text for the analytical expressions.
  • Figure 2: Relative error $\Delta \varepsilon = |(E_0 - E_{\text{ViT}})/E_0|$ during the optimization of the ViT wave function on the $J_1$-$J_2$ Heisenberg model at $J_2/J_1=0$ (left panel) and at $J_2/J_1 = 0.5$ (right panel) on a ${6 \times 6}$ lattice with periodic boundary conditions. The exact energies $E_{0}$ are computed with exact-diagonalization approaches. The architectures used for the simulations have $h=10$ heads, embedding dimension $d=60$, linear patch size $b=2$, $n_l=1$ layer in panels (a),(c), and $n_l=4$ layers in panels (b),(d). All networks are trained with the same optimization protocol, using SR (see Section \ref{subsec:VMC}) for $5\times 10^3$ optimization steps with $M=6\times10^3$ samples for the stochastic estimates. Each optimization step corresponds to one iteration of the for loop in Algorithm \ref{algorithm}. A cosine decay learning rate scheduler is applied, starting with an initial value of $\tau=0.03$. The optimization curves are consistent across multiple runs with different random initializations of the parameters.
  • Figure 3: Left panel: Relative error $\Delta \varepsilon = |(E_0 - E_{\text{ViT}})/E_0|$ at $J_2/J_1=0.5$ as a function of the system size for ViT architectures with the three different attention mechanisms, namely Factored (orange circles), Decoupled (green squares) and T5 (blue diamonds). The reference ground state energies are taken from exact diagonalization for $L=6$ ($-0.503810$) \cite{schulz1996} and from variance extrapolation for $L=8$ ($-0.49906$) \cite{BeccaGutz2013} and $L=10$ ($-0.497715$) \cite{chen2023}. Right panel: Time per optimization step (in seconds) measured on a single A100 GPU for the three attention mechanisms as a function of the system size. For all simulations a ViT architecture with hyperparameters $d=10$, $h=10$, $b=2$ and $n_l=4$ is considered. The model is optimized using the SR optimization method for $5 \times 10^3$ steps, employing $M=6\times10^3$ training samples (see Section \ref{subsec:VMC}). A cosine decay learning rate scheduler is applied, starting with an initial value of $\tau=0.03$.
  • Figure 4: Panel (a): Visualizations of the attention maps of a ViT with T5 attention mechanism [see Eq. \ref{eq:t5}] for three different input spin configurations. With the initial random parameters there is a clear input dependence in the attention maps (top row). At the end of the optimization, instead, the attention maps are practically input-independent (bottom row). Panel (b): Visualizations of the input-dependent term (left panels) and of the input-independent term (right panels) of a ViT with Decoupled attention mechanism [see Eq. \ref{eq:decoupled}]. After the optimization (bottom row), the input-dependent term is approximately the identity matrix shifted element-wise by a constant, thus Factored attention is recovered [see Eq. \ref{eq:factored}]. In the plots, the input-dependent term has been averaged over $M=6\times10^3$ input configurations sampled from the optimized state. The presented results are obtained by optimizing a ViT architecture with a single layer $n_l=1$, embedding dimension $d=60$ and $h=10$ different heads on a $6\times 6$ lattice at $J_2/J_1=0.5$ (see panel (c) of Fig. \ref{fig:opt}). The linear patch size is taken to be $b=2$, thus we have $n=9$ patches and the resulting attention maps have shape $9\times 9$. The plots are obtained by averaging the attention weights over all heads.
  • Figure 5: Graphical representation of the ground state of the Shastry-Sutherland model in the dimer phase \cite{shastry1981} on a $6 \times 6$ lattice (periodic boundary connections not shown for clarity). The green shaded regions denote singlet states between two next-nearest-neighbor spins. The blue $b \times b$ squares indicate the patches used to construct the input set of vectors for the Transformer.
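
The captions above rely on two small computations: splitting an $L \times L$ spin configuration into $b \times b$ patches (so that $L=6$, $b=2$ gives $n=9$ input vectors) and the relative error $\Delta \varepsilon = |(E_0 - E_{\text{ViT}})/E_0|$. A minimal NumPy sketch of both follows. The helper names `to_patches` and `relative_error`, the row-major patch ordering, and the variational energy used in the example are illustrative assumptions, not values or code from the paper.

```python
import numpy as np

def to_patches(spins, b):
    """Split an (L, L) spin configuration into (L/b)**2 flattened b x b patches.

    Patches are ordered row-major over the patch grid (an assumption).
    """
    L = spins.shape[0]
    if spins.shape != (L, L) or L % b != 0:
        raise ValueError("expected a square lattice with L divisible by b")
    g = L // b  # number of patches per side
    # (L, L) -> (g, b, g, b) -> (g, g, b, b) -> (g*g, b*b)
    return spins.reshape(g, b, g, b).transpose(0, 2, 1, 3).reshape(g * g, b * b)

def relative_error(e_exact, e_var):
    """Relative error |(E_0 - E_var) / E_0| of a variational energy."""
    return abs((e_exact - e_var) / e_exact)

rng = np.random.default_rng(0)
spins = rng.choice([-1, 1], size=(6, 6))  # random configuration of +/-1 spins
patches = to_patches(spins, b=2)          # 9 patches of 4 spins each

# E_0 for L=6 is quoted in the Figure 3 caption; e_var is a made-up
# placeholder used only to exercise the function.
e_exact = -0.503810
e_var = -0.5
delta_eps = relative_error(e_exact, e_var)
```

Each row of `patches` is one input vector for the Transformer, so an attention map over these inputs has shape $n \times n = 9 \times 9$, matching the Figure 4 caption.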