Eureka-Moments in Transformers: Multi-Step Tasks Reveal Softmax Induced Optimization Problems

David T. Hoffmann, Simon Schrodi, Jelena Bratulić, Nadine Behrmann, Volker Fischer, Thomas Brox

TL;DR

The paper investigates why transformers exhibit Eureka-moments (sudden, rapid improvements) in multi-step tasks, focusing on a controlled synthetic two-step setup in which the intermediate cue must be inferred from the first sub-task to solve the second. It identifies Softmax-induced optimization issues as the root cause: ill-distributed attention yields vanishing gradients for $W_Q$ and $W_K$, which prevents the model from using the intermediate information effectively. By applying targeted interventions such as NormSoftmax and Heat Treatment (HT), and by tuning the attention temperature $\tau$, the authors demonstrate substantially faster convergence, higher final accuracy, and greater robustness to hyperparameters, with results that transfer to real-data tasks such as language modeling with RoBERTa and ICL with GPT-2. The work provides both mechanistic insights and practical remedies for improving multi-step reasoning in transformers, and the code and datasets are publicly released. The product $p(y|x,z)\cdot p(z|x)$ frames the core multi-step objective, while the attention mechanism $\text{Attention}(Q,K,V)=\text{Softmax}(QK^T/\tau)V$ underpins the identified optimization bottleneck.
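The attention formula above can be probed numerically. The following is a minimal NumPy sketch, not the authors' code, of temperature-scaled attention $\text{Softmax}(QK^T/\tau)V$ together with the Jacobian of the Softmax with respect to the raw scores. It assumes a single head, no $\sqrt{d}$ scaling, and tiny random $Q$ and $K$ to mimic initialization, where attention is close to uniform. In that regime the Jacobian, and therefore the gradient that reaches $W_Q$ and $W_K$ through the scores, grows roughly as $1/\tau$. The function names are hypothetical, and this is a simplified illustration of how lowering $\tau$ changes gradient magnitudes, not a reproduction of the paper's analysis.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max to avoid overflow
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(Q, K, V, tau=1.0):
    """Temperature-scaled attention Softmax(Q K^T / tau) V (single head, no sqrt(d) factor).

    Returns the output and the attention map A, whose rows sum to 1.
    """
    A = softmax(Q @ K.T / tau, axis=-1)
    return A @ V, A


def softmax_jacobian_wrt_scores(p, tau):
    """Jacobian of p = softmax(s / tau) w.r.t. the raw scores s = q . k for one query.

    dp_i / ds_j = p_i * (delta_ij - p_j) / tau, so every gradient that reaches
    Q and K through the scores is scaled by 1 / tau.
    """
    return (np.diag(p) - np.outer(p, p)) / tau


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 8, 16
    # Small values mimic initialization: scores are near zero, attention is near uniform.
    Q = 0.01 * rng.standard_normal((n, d))
    K = 0.01 * rng.standard_normal((n, d))
    V = rng.standard_normal((n, d))
    for tau in (1.0, 1.0 / 3.0):
        _, A = attention(Q, K, V, tau)
        J = softmax_jacobian_wrt_scores(A[0], tau)
        # L1 size of the score Jacobian for the first query; roughly 3x larger at tau = 1/3.
        print(f"tau={tau:.3f}  L1(J)={np.abs(J).sum():.3f}")
```

Scaling the logits by $1/\tau$ is the only change between the two runs, so the roughly threefold gap at $\tau=\frac{1}{3}$ comes directly from the chain rule while attention stays near uniform.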

Abstract

In this work, we study rapid improvements of the training loss in transformers when confronted with multi-step decision tasks. We found that transformers struggle to learn the intermediate task, and both training and validation loss saturate for hundreds of epochs. When transformers finally learn the intermediate task, they do so rapidly and unexpectedly. We call these abrupt improvements Eureka-moments, since the transformer appears to suddenly learn a previously incomprehensible concept. We designed synthetic tasks to study the problem in detail, but the leaps in performance can also be observed for language modeling and in-context learning (ICL). We suspect that these abrupt transitions are caused by the multi-step nature of these tasks. Indeed, we find connections between them and show that methods that improve training on the synthetic multi-step tasks can also be used to improve the training of language modeling and ICL. Using the synthetic data, we trace the problem back to the Softmax function in the self-attention block of transformers and show ways to alleviate it. These fixes reduce the required number of training steps, increase the likelihood of learning the intermediate task, lead to higher final accuracy, and make training more robust to hyper-parameters.
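The product $p(y|x,z)\cdot p(z|x)$ named in the TL;DR can be read through the standard marginalization over an intermediate variable $z$. The sketch below assumes that $z$ is the target location of the synthetic task (as defined in the Figure 1 caption) and that task 1 fixes it deterministically through a hypothetical function $z^{*}(x)$; this reading is an illustration, not an equation quoted from the paper.

```latex
p(y \mid x) = \sum_{z} p(y \mid x, z)\, p(z \mid x),
\qquad
p(z \mid x) = \mathbb{1}\!\left[z = z^{*}(x)\right]
\;\Rightarrow\;
p(y \mid x) = p\!\left(y \mid x, z^{*}(x)\right),
\qquad
z^{*}(x) =
\begin{cases}
\text{top-right}, & \text{if the two indicators match},\\
\text{bottom-left}, & \text{otherwise}.
\end{cases}
```

Under this assumption, predicting $y$ correctly requires first recovering $z^{*}(x)$ from the indicators.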

Paper Structure (30 sections, 7 equations, 23 figures, 11 tables)

Figures (23)

  • Figure 1: Transformers can get stuck during optimization for two-stage tasks. (a) Describes our two-step decision task used to study Eureka-moments. Task 1 is to compare the two indicators (here digits). If the digits are the same, task 2 is to classify the top-right image; otherwise, it is to classify the bottom-left image. Top-right and bottom-left are referred to as targets. The location of the correct target is referred to as target location. (b) Validation accuracy and training loss for the task in (a). Two ViTs (blue and green) fail to converge, while one ViT (yellow) has a Eureka-moment. Eureka-moments are characterised by a sudden increase of accuracy and drop of the loss (in contrast to Grokking [grokking]). ResNets are not susceptible to this kind of optimization difficulty. (c) Eureka-moments on real datasets. Sharp improvements after initial plateauing can also be observed for GPT-2 ICL (here on the Omniglot ICL task [chan2022data]) and for language modeling with RoBERTa on Wikipedia. We will show later that our analysis transfers to these tasks (see Fig. \ref{fig:real_life_examples_normsoft}).
  • Figure 2: What is represented in different parts of the attention block. Bar plots show linear probe accuracy averaged over heads. Indicator 1 is the top MNIST digit. Both ViT and ViT $\tau=\frac{1}{3}$ extract the indicator class information from the images, and this information is available in every layer. Information is available before and after the residual connection; therefore, it is not entirely ignored by the attention. Differences between ViT and ViT $\tau=\frac{1}{3}$ are visible for the CLS token and the target location task. Res. denotes residual layer. Error bars show variance over heads. Results are for layer 6 using $Z_i$. The black line indicates chance. Indicator 2 plots are similar. More layers and indicator 2 plots are shown in Fig. \ref{fig:linprobe_z}. $Q$, $K$, $V$ linear probes are shown in Figs. \ref{fig:linprobe_q} to \ref{fig:linprobe_v}. A similar analysis using more sensitive information-theoretic probes [voita2020information] can be found in Sec. \ref{sec:info_probes}.
  • Figure 3: Attention maps after training for: (a) ViT without Eureka-moment. It fails to compare the two digits. The first layers explicitly ignore the indicators (digits) (highlighted in red). (b) ViT $\tau=\frac{1}{3}$ with Eureka-moment attends to the indicators in the first layers (red) and predominantly attends to the correct target (ankle boot) in later layers. Black indicates no attention and white indicates high attention. Maps show the average attention of each query, i.e., we average over the key dimension of the attention map.
  • Figure 4: L1 gradient norm during training for $W_K$, $W_Q$ and $W_V$ in the first layer. For ViT, $W_K$ and $W_Q$ receive much smaller gradients than $W_V$. Before the Eureka-moment (gray regions), the differences between gradient magnitudes are much smaller for smaller temperatures or NormSoftmax. The y-axis is log-scaled. All layers are shown in Fig. \ref{fig:grad_mag_full}.
  • Figure 5: Gradients on the image for $W_K$ at epoch 50. For ViT, the gradient for $W_K$ comes mostly from target regions, while for the other approaches the indicator regions provide substantial gradient. A detailed explanation of this plot and plots for $Q$ and $V$ can be found in Sec. \ref{sec:split_grad_indi_target}.
  • ...and 18 more figures
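The labeling rule from the Figure 1 caption can be written in a few lines of Python. This is a minimal sketch that assumes each sample has already been reduced to class labels; the names `TwoStepSample`, `target_location`, and `label` are hypothetical and are not taken from the authors' released code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TwoStepSample:
    """Hypothetical per-sample record for the task in Fig. 1(a), reduced to class labels."""
    indicator_1: int         # digit class of the first indicator (the top MNIST digit)
    indicator_2: int         # digit class of the second indicator
    target_top_right: int    # class of the top-right target image
    target_bottom_left: int  # class of the bottom-left target image


def target_location(s: TwoStepSample) -> str:
    """Task 1: compare the two indicators to decide which target is relevant."""
    return "top_right" if s.indicator_1 == s.indicator_2 else "bottom_left"


def label(s: TwoStepSample) -> int:
    """Task 2: return the class of the target at the location chosen by task 1."""
    if target_location(s) == "top_right":
        return s.target_top_right
    return s.target_bottom_left


if __name__ == "__main__":
    same = TwoStepSample(indicator_1=7, indicator_2=7, target_top_right=9, target_bottom_left=2)
    diff = TwoStepSample(indicator_1=7, indicator_2=1, target_top_right=9, target_bottom_left=2)
    print(label(same), label(diff))  # prints: 9 2
```

The label depends on the indicators only through `target_location`, so classifying the targets alone is not sufficient; the intermediate comparison must also be learned.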