Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics

Hanlin Zhu, Baihe Huang, Shaolun Zhang, Michael Jordan, Jiantao Jiao, Yuandong Tian, Stuart Russell

TL;DR

This paper theoretically analyzes the reversal curse via the training dynamics of (stochastic) gradient descent for two auto-regressive models: (1) a bilinear model that can be viewed as a simplification of a one-layer transformer; (2) one-layer transformers under certain assumptions.

Abstract

Auto-regressive large language models (LLMs) show impressive capacities to solve many complex reasoning tasks while struggling with some simple logical reasoning tasks such as inverse search: when trained on '$A \to B$' (e.g., 'Tom is the parent of John'), LLMs fail to directly conclude '$B \gets A$' (e.g., 'John is the child of Tom') during inference even though the two sentences are semantically identical, a phenomenon known as the 'reversal curse'. In this paper, we theoretically analyze the reversal curse via the training dynamics of (stochastic) gradient descent for two auto-regressive models: (1) a bilinear model that can be viewed as a simplification of a one-layer transformer; (2) one-layer transformers under certain assumptions. Our analysis reveals that for both models, the reversal curse is a consequence of the 'asymmetry' of the (effective) model weights, i.e., an increase in the weights from a token $A$ to a token $B$ during training does not necessarily cause an increase in the weights from $B$ to $A$. This asymmetry is caused by the training dynamics under certain choices of loss function and by the optimization space of model parameters. Moreover, our analysis naturally applies to other logical reasoning tasks such as chain-of-thought (COT), providing a new perspective that differs from previous work focusing on expressivity. Finally, we conduct experiments to validate our theory on multi-layer transformers under different settings. Our code is available at https://github.com/marlo-z/reversal_curse_analysis/.
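
To make the weight 'asymmetry' concrete, here is a minimal sketch under a strong simplifying assumption: a tabular next-token model in which the logit from input token x to output token y is a single free weight W[x][y], initialized at zero and trained with cross-entropy by per-example gradient descent. This is not the paper's bilinear model or transformer, relation tokens such as "→" and "←" are dropped, and the names `softmax`, `train_forward_only`, `V`, and the token ids are hypothetical. Because the gradient of the loss for input A only touches row A, training on A → B raises W[A][B] but never touches W[B][A], so P(A | B) stays at the uniform value 1/V.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def train_forward_only(pairs, vocab_size, lr=1.0, steps=200):
    """Per-example gradient descent on -log P(b | a) for each (a, b) pair."""
    # W[x][y] is the logit of next token y given input token x; zero initialization.
    W = [[0.0] * vocab_size for _ in range(vocab_size)]
    for _ in range(steps):
        for a, b in pairs:
            p = softmax(W[a])
            # d(-log p[b]) / dW[a][y] = p[y] - 1{y == b}; only row a is updated.
            for y in range(vocab_size):
                W[a][y] -= lr * (p[y] - (1.0 if y == b else 0.0))
    return W

V = 6
pairs = [(0, 3), (1, 4), (2, 5)]  # (A_i, B_i): only the direction A_i -> B_i is trained
W = train_forward_only(pairs, V)
forward_p = softmax(W[0])[3]  # P(B_1 | A_1), trained direction
reverse_p = softmax(W[3])[0]  # P(A_1 | B_1), row 3 never receives gradient
print(f"P(B|A) = {forward_p:.3f}, P(A|B) = {reverse_p:.3f} (uniform = {1 / V:.3f})")
```

In this toy model the asymmetry is built in by construction; the paper's contribution is to show that an analogous separation still emerges from the training dynamics when weights are shared through embeddings, as in the bilinear model and one-layer transformers.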

Paper Structure (51 sections, 22 theorems, 115 equations, 18 figures, 3 tables)

Key Result

Theorem 1

Fix any $\delta,\epsilon \in (0,1)$. For small $\sigma$ and $d \geq \operatorname{poly}(n,m,1/\epsilon,\log(1/\delta))$, with probability at least $1-\delta$, we have

Figures (18)

  • Figure 1: Experiment results of the reversal curse under the default configuration (see the model and data configuration table in the appendix). The curves represent the (average) negative log probability of the model predicting the next token to be $\texttt{B}_i$ when the input is "$\texttt{A}_i\to$", or to be $\texttt{A}_i$ when the input is "$\texttt{B}_i\gets$". While the sentences in the training set can be learned nearly perfectly (as shown by the training curve, where the next-token probability converges to one), the model cannot predict the correct next token in the validation set better than a uniformly random guess. Both curves are averaged over 10 random seeds.
  • Figure 2: Visualization of the weights (logits) of the model with default configurations after 3000 training epochs for the reversal curse experiment. For the top-left matrix, the $i$-th row corresponds to an entity token $\texttt{A}_i$ of a training pair, and the $i$-th column corresponds to an entity token $\texttt{B}_i$ of a training pair. The $(i,j)$-th entry represents the model weight from token $\texttt{A}_i$ to $\texttt{B}_j$, i.e., the logit of $\texttt{B}_j$ when the input sequence consists only of $\texttt{A}_i$. Similarly, for the bottom-left matrix, the rows correspond to the input entity tokens of the seen direction (the direction included in the training set) of validation pairs, and the columns correspond to output entity tokens. The two matrices on the right are obtained by swapping the row and column tokens of the corresponding left matrices. Note that the diagonal entries of the bottom-right matrix are all close to zero, while the diagonals of the other matrices all have large values. This implies that if a pair of tokens $(\texttt{A}, \texttt{B})$ appears in the training set in only one direction, the model weights associated with the other direction are hardly trained.
  • Figure 3: Experiment results of COT under the default configuration (see the model and data configuration table in the appendix). The curves represent the (average) negative log probability of the model predicting the next token to be: (1) $\texttt{B}_i$ given the input "$\texttt{A}_i \to$", (2) $\texttt{C}_i$ given the input "$\texttt{B}_i \to$", or (3) $\texttt{C}_i$ given the input "$\texttt{A}_i \leadsto$". As in the reversal curse experiment, while the sentences in the training set can be learned nearly perfectly, the model cannot predict the correct next token in the validation set better than a uniformly random guess. Both curves are averaged over 10 random seeds.
  • Figure 4: Visualization of the weights (logits) of the model with default configurations after 3000 training epochs for the COT experiment. The matrices are analogous to those in Figure 2. For training triples, the row tokens of the top matrices are $\texttt{A}_i$, $\texttt{B}_i$, $\texttt{A}_i$ and the column tokens are $\texttt{B}_i$, $\texttt{C}_i$, $\texttt{C}_i$, respectively. Similarly, the bottom matrices correspond to validation triples. For validation triples ($\texttt{A}_i, \texttt{B}_i, \texttt{C}_i$), the weights from $\texttt{A}_i$ to $\texttt{C}_i$ are hardly trained, as indicated by the diagonal of the last matrix.
  • Figure 5: Results for the reversal curse for different vocabulary sizes. All other configurations are set to the default values (see the model and data configuration table in the appendix). The training set sizes for the four experiments are $9$, $20$, $85$, $850$, respectively, and the validation set sizes are $1$, $4$, $15$, $150$, respectively.
  • ...and 13 more figures
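
The COT pattern in Figures 3 and 4 can be mimicked with a tabular toy model, again an assumption for illustration rather than the paper's setup. Each context is a hypothetical (entity, relation) pair, with "->" standing for the one-hop relation and "~>" for the two-hop relation ($\leadsto$); `train`, `logits`, `predict`, and the token ids are likewise hypothetical. For the validation triple only the one-hop facts are trained, so the direct query A ⇝ C keeps its zero-initialized logits (a uniform guess), while chaining the two trained one-hop steps recovers C.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def train(examples, vocab_size, lr=1.0, steps=200):
    """Per-example gradient descent on cross-entropy; W maps a context to next-token logits."""
    W = {}
    for _ in range(steps):
        for ctx, target in examples:
            row = W.setdefault(ctx, [0.0] * vocab_size)  # zero init on first use
            p = softmax(row)
            for y in range(vocab_size):
                row[y] -= lr * (p[y] - (1.0 if y == target else 0.0))
    return W

def logits(W, ctx, vocab_size):
    # A context never seen in training keeps all-zero logits (uniform prediction).
    return W.get(ctx, [0.0] * vocab_size)

def predict(W, ctx, vocab_size):
    row = logits(W, ctx, vocab_size)
    return max(range(vocab_size), key=lambda y: row[y])

V = 6
a, b, c = 0, 1, 2     # training triple: A -> B, B -> C, and A ~> C are all seen
a2, b2, c2 = 3, 4, 5  # validation triple: only A -> B and B -> C are seen
examples = [((a, "->"), b), ((b, "->"), c), ((a, "~>"), c)]
examples += [((a2, "->"), b2), ((b2, "->"), c2)]

W = train(examples, V)
direct_p = softmax(logits(W, (a2, "~>"), V))[c2]  # direct two-hop query: uniform 1/V
hop1 = predict(W, (a2, "->"), V)                  # first step of the chain
hop2 = predict(W, (hop1, "->"), V)                # second step of the chain
print(f"P(C | A ~>) = {direct_p:.3f}; chain-of-thought answer = {hop2}")
```

As with the reversal sketch, the lack of transfer is trivially true in a tabular model; the point of the paper's analysis is that the same behavior arises from the training dynamics of models with shared parameters.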

Theorems & Definitions (40)

  • Theorem 1: Separation of training dynamics (informal statement of Theorem 5)
  • Theorem 2: Lower bound of reversal loss (informal statement of the corresponding formal corollary)
  • Lemma 1: Gradient of $Y$ and $Z$ for 1-layer transformer, Lemma 1 of Tian et al. (2023)
  • Proof
  • Proposition 4.1: Initial probability under zero initialization
  • Proposition 4.2: Next token probability
  • Lemma 2: Dynamics of $Y(t)$
  • Theorem 3: Reversal curse
  • Theorem 4: Importance of chain-of-thought, informal statement of the corresponding formal proposition
  • Theorem 5: Separation of training dynamics, formal statement of Theorem 1
  • ...and 30 more