Learning to Unscramble: Simplifying Symbolic Expressions via Self-Supervised Oracle Trajectories

David Shih

Abstract

We present a new self-supervised machine learning approach for symbolic simplification of complex mathematical expressions. Training data is generated by scrambling simple expressions and recording the inverse operations, creating oracle trajectories that provide both goal states and explicit paths to reach them. A permutation-equivariant, transformer-based policy network is then trained step-wise on this data to predict the oracle action given the input expression. We demonstrate this approach on two problems in high-energy physics: dilogarithm reduction and spinor-helicity scattering amplitude simplification. In both cases, our trained policy network achieves near-perfect solve rates across a wide range of difficulty levels, substantially outperforming prior approaches based on reinforcement learning and end-to-end regression. When combined with contrastive grouping and beam search, our model achieves a 100% full simplification rate on a representative selection of 5-point gluon tree-level amplitudes in Yang-Mills theory, including expressions with over 200 initial terms.
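As a rough illustration of the data-generation idea in the abstract, the Python sketch below scrambles a toy expression with random invertible rewrites and records each inverse, so that every scrambled intermediate expression is paired with the action that undoes the most recent scramble step. Everything concrete here is a hypothetical stand-in, not the paper's rule set: a term `(c, x)` stands for `c * f(x)` with an unspecified function f, `reflect` and `invert` only mimic the shape of the dilogarithm reflection and inversion identities (constant and logarithm terms are dropped), and `add_zero` inserts a cancelling pair. The function names `apply`, `inverse`, `scramble`, and `replay` are illustrative only.

```python
import random
from fractions import Fraction

# HYPOTHETICAL toy rewrite system: a term (c, x) stands for c * f(x) for some
# unspecified function f. These rules only mimic the shape of the dilogarithm
# identities (constants and log terms dropped); they are not the paper's rules.


def apply(state, action):
    """Apply one rewrite action to a tuple of terms and return the new tuple."""
    kind, *args = action
    terms = list(state)
    if kind == "reflect":            # toy analogue of f(x) -> -f(1 - x)
        (i,) = args
        c, x = terms[i]
        terms[i] = (-c, 1 - x)
    elif kind == "invert":           # toy analogue of f(x) -> -f(1/x), x != 0
        (i,) = args
        c, x = terms[i]
        terms[i] = (-c, 1 / x)
    elif kind == "add_zero":         # append the cancelling pair f(y) - f(y)
        (y,) = args
        terms += [(1, y), (-1, y)]
    elif kind == "cancel":           # remove a cancelling pair at positions i < j
        i, j = args
        assert terms[i][1] == terms[j][1] and terms[i][0] == -terms[j][0]
        del terms[j], terms[i]       # delete the later index first
    else:
        raise ValueError(f"unknown action {kind!r}")
    return tuple(terms)


def inverse(state, action):
    """Return the action that undoes `action`, given the state it was applied to."""
    kind = action[0]
    if kind in ("reflect", "invert"):
        return action                # both toy maps are involutions
    if kind == "add_zero":
        n = len(state)
        return ("cancel", n, n + 1)  # the pair was appended at positions n, n + 1
    raise ValueError(f"no inverse defined for {kind!r}")


def scramble(simple, depth, rng):
    """Scramble `simple` for up to `depth` steps.

    Returns the scrambled expression and the oracle trajectory: a list of
    (expression, oracle action) pairs ordered from the scrambled end back
    toward the simple expression.
    """
    state, history = tuple(simple), []
    for _ in range(depth):
        kind = rng.choice(["reflect", "invert", "add_zero"])
        if kind == "add_zero":
            action = (kind, Fraction(rng.randint(-5, 5), rng.randint(1, 5)))
        else:
            valid = [i for i, (_, x) in enumerate(state) if kind == "reflect" or x != 0]
            if not valid:
                continue             # no term this rule can act on; skip the step
            action = (kind, rng.choice(valid))
        new_state = apply(state, action)
        history.append((new_state, inverse(state, action)))
        state = new_state
    return state, history[::-1]


def replay(scrambled, trajectory):
    """Follow the oracle actions from the scrambled expression; return the end state."""
    state = scrambled
    for expr, oracle_action in trajectory:
        assert expr == state         # each recorded expression matches the rollout
        state = apply(state, oracle_action)
    return state


if __name__ == "__main__":
    simple = ((1, Fraction(1, 2)), (2, Fraction(3)))
    scrambled, trajectory = scramble(simple, depth=6, rng=random.Random(0))
    print(scrambled)
    assert replay(scrambled, trajectory) == simple
```

Read in order, the reversed history is a list of (expression, oracle action) pairs, which is the kind of step-wise supervision a policy network can be trained on. The oracle path is not necessarily the shortest one: in this toy, for example, two consecutive reflections of the same term cancel, which is the sort of redundancy a trained model might learn to skip.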

Paper Structure

This paper contains 24 sections, 16 equations, 6 figures, and 2 tables.

Figures (6)

  • Figure 1: Architecture of the policy network for symbolic simplification. Each term's feature vector is embedded and processed by a Transformer encoder with a prepended learnable [CLS] token. No positional encoding is used, respecting the permutation symmetry of terms. The permutation-equivariant policy head takes the output of the Transformer and returns a softmax probability distribution over the action space.
  • Figure 2: Solve rate vs. scramble depth for dilogarithm simplification. Our model (blue) maintains near-100% performance under both source-relative and target-relative criteria, even beyond the training range (shaded). The seq2seq model of DSZ (orange) degrades at higher scramble depths.
  • Figure 3: Average number of steps to solve vs. scramble depth for dilogarithm simplification, with $\pm 1$ standard deviation error bars. The dashed line marks $y=x$ (steps equal to scramble depth); the shaded region indicates the training range (scramble depths 1 to 7). The model consistently finds paths shorter than the scramble depth, indicating that it learns to bypass redundancy introduced by the scrambling process.
  • Figure 4: Solve rate vs. number of target terms for 4-point (left), 5-point (center), and 6-point (right) amplitudes. Performance is shown under source-relative (dashed) and target-relative (solid) criteria for our model (blue) and CDS with $B{=}20$ (red). Gray bars show the number of test samples at each target term count.
  • Figure 5: Solve rate vs. source bracket count for 4-point (left), 5-point (center), and 6-point (right) amplitudes. Our model (blue) maintains near-100% solve rates across all starting complexities under both source-relative (dashed) and target-relative (solid) criteria. The model of CDS with beam size $B{=}20$ (red) degrades with increasing bracket count. Gray bars show the number of test samples at each bracket count.
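The Figure 1 caption describes a Transformer encoder over per-term feature vectors with a prepended learnable [CLS] token, no positional encoding, and a permutation-equivariant softmax policy head. The pure-Python sketch below illustrates why that combination respects the symmetry of an unordered sum of terms. It assumes one single-head attention layer with random (untrained) weights and a hypothetical action space of (term index, action type) pairs; the class name `ToyPolicy`, all layer sizes, and the choice to leave the [CLS] output unused in the head are illustrative assumptions, not details of the paper's network.

```python
import math
import random


def matmul(a, b):
    """Multiply matrices stored as lists of rows: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(xs):
    """Numerically stable softmax over a flat list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


class ToyPolicy:
    """HYPOTHETICAL one-layer, single-head attention policy with random weights."""

    def __init__(self, n_features, d_model, n_action_types, seed=0):
        rng = random.Random(seed)

        def rand(rows, cols):
            return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

        self.w_embed = rand(n_features, d_model)     # per-term embedding
        self.cls = rand(1, d_model)                  # [CLS] token (learned in practice)
        self.w_q = rand(d_model, d_model)
        self.w_k = rand(d_model, d_model)
        self.w_v = rand(d_model, d_model)
        self.w_head = rand(d_model, n_action_types)  # one head shared by all terms

    def __call__(self, terms):
        """Map a list of term feature vectors to {(term index, action type): probability}."""
        # Prepend [CLS]; no positional encoding is added anywhere.
        h = self.cls + matmul(terms, self.w_embed)
        q, k, v = matmul(h, self.w_q), matmul(h, self.w_k), matmul(h, self.w_v)
        scale = math.sqrt(len(h[0]))
        attended = []
        for q_i in q:
            # Attention of this position over all positions; the weighted sum
            # does not depend on the order in which the positions are listed.
            w = softmax([sum(a * b for a, b in zip(q_i, k_j)) / scale for k_j in k])
            attended.append([sum(w_j * v_j[d] for w_j, v_j in zip(w, v)) for d in range(len(v[0]))])
        # Residual connection, applied position by position.
        out = [[x + y for x, y in zip(h_i, a_i)] for h_i, a_i in zip(h, attended)]
        # Score every term row with the same head; the [CLS] row is not scored here.
        logits = matmul(out[1:], self.w_head)
        actions = [(i, t) for i in range(len(logits)) for t in range(len(logits[0]))]
        probs = softmax([logits[i][t] for i, t in actions])
        return dict(zip(actions, probs))


if __name__ == "__main__":
    policy = ToyPolicy(n_features=3, d_model=4, n_action_types=2)
    terms = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 3.0, 1.0]]
    perm = [2, 0, 1]
    p = policy(terms)
    p_perm = policy([terms[j] for j in perm])
    # Reordering the input terms only relabels the action probabilities.
    for j in range(len(terms)):
        for t in range(2):
            assert math.isclose(p_perm[(j, t)], p[(perm[j], t)], abs_tol=1e-12)
    print("permutation equivariance holds")
```

A standard Transformer encoder adds multiple heads and layers, feed-forward sublayers, and layer normalization; the latter two act on each position independently, so they leave this symmetry intact. The property the sketch checks is that nothing in the computation depends on where a term sits in the input list.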