
On Limitation of Transformer for Learning HMMs

Jiachen Hu, Qinghua Liu, Chi Jin

TL;DR

This work benchmarks Transformers against RNNs on learning Hidden Markov Models (HMMs) and their variants. It shows that Transformers typically lag RNNs in both training speed and final accuracy, on belief-state inference as well as next-observation prediction. It introduces block Chain-of-Thought (block CoT) training, which extends the sequence lengths that shallow Transformers can handle at the cost of additional computation. It also provides theoretical results showing that Transformer expressiveness grows with depth: in certain settings, an $L$-layer Transformer can represent HMMs of sequence length up to $2^L$. The study systematically constructs fast- and slow-mixing HMMs, including deterministic and stochastic cyclic variants, to expose the limitations and scaling laws of Transformers, and demonstrates that curriculum scheduling further improves training dynamics. Overall, the paper highlights fundamental trade-offs in Transformer-based modeling of simple probabilistic sequence models and suggests practical strategies for mitigating these limitations in long-horizon tasks.
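The two tasks named above can be made concrete with the standard forward (filtering) recursion for an HMM. The sketch below assumes a row-stochastic transition matrix `T`, an emission matrix `O`, and a prior `b0` (names chosen here for illustration, not taken from the paper). It computes the belief state, i.e. the posterior over the hidden state given the observations so far, and the distribution of the next observation, which are the quantities the two tasks target.

```python
import numpy as np

def belief_states(T, O, b0, obs):
    """Exact HMM filtering (forward algorithm).

    T[i, j] = P(s_{t+1}=j | s_t=i), O[i, k] = P(o_t=k | s_t=i),
    b0 = prior over the first hidden state, obs = observed symbols.
    beliefs[t] is the posterior over the hidden state at step t given obs[0..t].
    """
    beliefs = []
    prior = np.asarray(b0, dtype=float)
    for o in obs:
        post = prior * O[:, o]      # condition on the new observation
        post = post / post.sum()    # renormalize to a distribution
        beliefs.append(post)
        prior = post @ T            # propagate one step through the dynamics
    return np.array(beliefs)

def next_observation_probs(T, O, belief):
    """Distribution of the next observation given the current belief state."""
    return belief @ T @ O
```

Both tasks reduce to matrix-vector products, which is why a recurrent model that carries the belief forward one step at a time is a natural fit for them.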

Abstract

Despite the remarkable success of Transformer-based architectures in various sequential modeling tasks, such as natural language processing, computer vision, and robotics, their ability to learn basic sequential models, like Hidden Markov Models (HMMs), is still unclear. This paper investigates the performance of Transformers in learning HMMs and their variants through extensive experimentation and compares them to Recurrent Neural Networks (RNNs). We show that Transformers consistently underperform RNNs in both training speed and testing accuracy across all tested HMMs. There are even challenging HMM instances that Transformers struggle to learn but RNNs learn successfully. Our experiments further reveal how the longest sequence length a Transformer can effectively learn depends on its depth, and how this relation varies with the type and complexity of the HMM. To address this limitation, we demonstrate that a variant of Chain-of-Thought (CoT) used in the training phase, called $\textit{block CoT}$, helps Transformers reduce the evaluation error and learn longer sequences at the cost of longer training time. Finally, we complement our empirical findings with theoretical results proving that Transformers of logarithmic depth are expressive enough to approximate HMMs.
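One intuition for why logarithmic depth can suffice (offered as intuition only, not as the paper's proof) is that the unnormalized belief after $T$ steps is a product of $T$ observation-dependent matrices: $\alpha_T^\top = b_0^\top D_{o_1} \prod_{t=2}^{T} (P D_{o_t})$, where $P$ is the transition matrix, $D_o = \operatorname{diag}(O_{:,o})$ holds the emission probabilities of observation $o$, and the belief is $\alpha_T / \mathbf{1}^\top \alpha_T$. Because matrix multiplication is associative, pairwise merging evaluates such a product in $\lceil \log_2 T \rceil$ rounds, which loosely mirrors the claim that $L$ layers can handle sequences of length up to $2^L$. The helper name `tree_product` is hypothetical.

```python
import numpy as np

def tree_product(mats):
    """Multiply square matrices in left-to-right order by pairwise merging.

    Returns (product, rounds). Each round roughly halves the number of factors,
    so rounds == ceil(log2(len(mats))) for a non-empty list.
    """
    level = list(mats)
    rounds = 0
    while len(level) > 1:
        # Merge neighbours (0,1), (2,3), ... so the left-to-right order is kept.
        merged = [level[i] @ level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            merged.append(level[-1])  # an odd trailing factor waits for the next round
        level = merged
        rounds += 1
    return level[0], rounds
```

A sequential loop needs $T - 1$ dependent multiplications, whereas the merged version needs only $\lceil \log_2 T \rceil$ dependent rounds; how a Transformer layer would realize one round is not addressed by this sketch.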

Paper Structure

This paper contains 64 sections, 10 theorems, 62 equations, 4 figures, and 3 tables.

Key Result

Theorem 5.1

For a deterministic HMM with state space size $n$ and observation space size $m$, there exists a single-layer RNN model with embedding dimension $d = O(nm)$ and ReLU activation that computes the belief-state sequence of length $T$ with no error. The $\ell_\infty$ norm of the parameters of the RNN model …
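The $d = O(nm)$ embedding dimension is consistent with a standard pairing trick. The sketch below is not the paper's construction: it assumes the simplest deterministic setting, a known start state and a next state fixed by the current state and input symbol (given by a hypothetical table `delta`), so the belief state is always one-hot. It builds a one-layer ReLU RNN with $nm$ hidden units, where unit $(s, a)$ is on exactly when the system is in state $s$ and reads symbol $a$; in this sketch all weights lie in $\{0, 1\}$ and the bias is $-1$, and the state is tracked exactly for any length $T$.

```python
import numpy as np

def build_relu_rnn(delta, n, m):
    """Weights of a one-layer ReLU RNN with d = n * m hidden units (hypothetical setup).

    Recurrence: h_t = ReLU(W h_{t-1} + U x_t - 1); linear readout: y_t = R h_t.
    delta[s][a] is the next state after reading symbol a in state s.
    """
    d = n * m
    W, U, R = np.zeros((d, d)), np.zeros((d, m)), np.zeros((n, d))
    for s in range(n):
        for a in range(m):
            j = s * m + a
            U[j, a] = 1.0               # unit (s, a) requires input symbol a
            R[delta[s][a], j] = 1.0     # reading a in state s leads to delta[s][a]
            for a2 in range(m):
                W[delta[s][a] * m + a2, j] = 1.0  # next step starts in state delta[s][a]
    return W, U, R

def run_rnn(W, U, R, n, m, s0, inputs):
    """Decode the state reached after each input symbol, starting from known state s0."""
    h, states = None, []
    for a in inputs:
        if h is None:
            h = np.zeros(n * m)
            h[s0 * m + a] = 1.0                   # encode (s0, first symbol) directly
        else:
            x = np.eye(m)[a]                      # one-hot input symbol
            h = np.maximum(W @ h + U @ x - 1.0, 0.0)
        states.append(int(np.argmax(R @ h)))      # readout is one-hot on the next state
    return states
```

The pre-activation of unit $(s', a')$ equals $1$ only when the routed state is $s'$ and the input is $a'$; every other unit sees at most $0$ and is zeroed by the ReLU, so $h_t$ stays exactly one-hot at every step.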

Figures (4)

  • Figure 1: An illustration of the CyclicHMM-DET and CyclicHMM-HARD models. Left: A CyclicHMM-DET model with 4 states and 2 actions. The transition graph of each action is a cyclic permutation over the state space, and different actions may induce different cyclic permutations. Right: The CyclicHMM-HARD model transforms a given CyclicHMM-RND or CyclicHMM-DET model into a larger HMM. Its transitions always proceed through stages 1, 2, and 3 in order and then return to stage 1. The dotted line denotes a stochastic transition from stage 1 to stage 2 with probability $\alpha$, and solid lines denote deterministic transitions. States in stage 2 always emit a signal observation $*$ indicating the entrance to stage 3, and states in stage 3 emit the current state as the observation.
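The transition structure described for CyclicHMM-DET can be sketched as follows, under stated assumptions: each action gets one uniformly random $n$-cycle over the states, and the sampler records only the action sequence and the hidden states, since the emission model is not specified in this caption. The function names and the uniform choice of cycles are hypothetical.

```python
import numpy as np

def random_cycle(n, rng):
    """A uniformly random cyclic permutation (a single n-cycle): perm[s] is the next state."""
    order = rng.permutation(n)
    perm = np.empty(n, dtype=int)
    perm[order] = np.roll(order, -1)  # order[i] -> order[i + 1], wrapping around
    return perm

def sample_cyclic_det(n, num_actions, length, seed=0):
    """Hypothetical CyclicHMM-DET-style sampler: one cyclic permutation per action.

    Returns (perms, actions, states), where states[t] is the state after actions[t].
    """
    rng = np.random.default_rng(seed)
    perms = [random_cycle(n, rng) for _ in range(num_actions)]
    s = int(rng.integers(n))                       # random initial state
    actions = rng.integers(num_actions, size=length)
    states = []
    for a in actions:
        s = int(perms[a][s])                       # deterministic move under action a
        states.append(s)
    return perms, actions, states
```

Because every action is a single $n$-cycle, the state never gets trapped in a smaller subset, so tracking it requires composing the whole action history.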
  • Figure 2: Evaluation loss of neural networks on 4 HMMs at a fixed sequence length. To illustrate the difference between RNNs and Transformers of different depths, the evaluation sequence lengths are 10, 30, 30, and 120 for the 4 tasks from left to right, respectively. For the CyclicHMM-HARD model, the evaluation loss only counts states in the prediction stage, since the predictions for the other stages are simply constant. The convergence speed and final accuracy of the RNN are at least as good as those of every Transformer, and strictly better in many cases.
  • Figure 3: Left and middle: The approximate scaling between the fit length at error rate 0.05/0.1 and the depth of the Transformer. "State" denotes a belief-state inference task, and "observation" denotes a next-observation prediction task. There are roughly 3 different scaling patterns across the tasks, which are affected by the mixing time and the hidden information in the training data. The closeness of the 0.05-fit and 0.1-fit lengths suggests that optimization issues are unlikely to distort the curves. Right: The evaluation loss at length 60 under 8/12 block CoT training for a 3-layer Transformer on the CyclicHMM-DET task and a 4-layer Transformer on the MatMul task. Neither reaches fit length 60 without block CoT (dashed curves). The evaluation loss drops dramatically with 8/12 block CoT, where 8/12 is approximately half of the corresponding 0.05-fit length.
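The caption reports 0.05- and 0.1-fit lengths without defining the term here. The sketch below assumes, as a hypothetical reading, that the $\epsilon$-fit length is the longest run of sequence lengths, starting from length 1, over which the evaluation error stays at or below $\epsilon$; the function name is illustrative.

```python
def fit_length(errors, eps):
    """Largest T such that errors[t] <= eps for every t < T (hypothetical definition).

    errors[t] is the evaluation error at sequence length t + 1.
    """
    T = 0
    for e in errors:
        if e > eps:
            break  # the first length whose error exceeds eps ends the fitted prefix
        T += 1
    return T
```

For example, an error curve that first exceeds 0.05 at length 4 has a 0.05-fit length of 3 under this reading, while its 0.1-fit length can be longer.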
  • Figure 4: The benefits of curriculum training. Left: The $0.1$-fit length of Transformers of different depths, reported as the best of 2 runs with different seeds. Curriculum training achieves fit lengths generally comparable to vanilla training on CyclicHMM-DET and better than vanilla training on MatMul. Right: The convergence speed of different Transformers on the MatMul model. The $0.1$-fit length converges consistently faster under curriculum training than under vanilla training.
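The captions compare curriculum and vanilla training without spelling out the schedule. As a hedged illustration only, a curriculum for this setting can be as simple as raising the training sequence length in stages; the staged linear schedule and every parameter name below are assumptions, not the paper's recipe.

```python
def curriculum_length(step, total_steps, start_len, max_len, num_stages):
    """Training sequence length at a given step under a staged linear schedule.

    Training is split into num_stages equal phases; the length rises linearly from
    start_len in the first phase to max_len in the last (hypothetical schedule).
    """
    if num_stages == 1:
        return max_len
    stage = min(step * num_stages // total_steps, num_stages - 1)
    return start_len + (max_len - start_len) * stage // (num_stages - 1)
```

Vanilla training corresponds to `num_stages=1`, which always returns `max_len`; short early sequences let the model fit short-range structure before it is asked to handle the full horizon.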

Theorems & Definitions (17)

  • Remark 3.1: Why is the CyclicHMM-HARD model hard?
  • Theorem 5.1
  • Theorem 5.2
  • proof : Proof Sketch.
  • Theorem 5.3
  • proof : Proof Sketch.
  • Remark 5.4
  • Proposition A.1
  • proof
  • Lemma C.1: Lemma C.1 of feng2023towards
  • ...and 7 more