
SymTorch: A Framework for Symbolic Distillation of Deep Neural Networks

Elizabeth S. Z. Tan, Adil Soubki, Miles Cranmer

TL;DR

SymTorch is a library that automates symbolic distillation: it wraps neural network components, collects their input-output behavior, and approximates them with human-readable equations via PySR. In a proof-of-concept that replaces LLM MLP layers with symbolic surrogates, it achieves an 8.3% throughput improvement with moderate performance degradation.

Abstract

Symbolic distillation replaces neural networks, or components thereof, with interpretable, closed-form mathematical expressions. This approach has shown promise in discovering physical laws and mathematical relationships directly from trained deep learning models, yet adoption remains limited due to the engineering barrier of integrating symbolic regression into deep learning workflows. We introduce SymTorch, a library that automates this distillation by wrapping neural network components, collecting their input-output behavior, and approximating them with human-readable equations via PySR. SymTorch handles the engineering challenges that have hindered adoption: GPU-CPU data transfer, input-output caching, model serialization, and seamless switching between neural and symbolic forward passes. We demonstrate SymTorch across diverse architectures including GNNs, PINNs and transformer models. Finally, we present a proof-of-concept for accelerating LLM inference by replacing MLP layers with symbolic surrogates, achieving an 8.3% throughput improvement with moderate performance degradation.
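The wrap, record, and switch workflow described in the abstract can be illustrated with a small sketch. This is not SymTorch's API: the class `RecordingWrapper`, its methods, and the toy component are hypothetical names invented for illustration, and the surrogate equation is written by hand here, whereas the abstract says SymTorch obtains it from PySR.

```python
from typing import Callable, List, Optional


class RecordingWrapper:
    """Hypothetical wrapper (not SymTorch's API) around a callable component.

    It caches input-output pairs during the normal forward pass and can
    switch to a closed-form surrogate afterwards.
    """

    def __init__(self, component: Callable[[float], float]) -> None:
        self.component = component
        self.inputs: List[float] = []
        self.outputs: List[float] = []
        self.surrogate: Optional[Callable[[float], float]] = None

    def __call__(self, x: float) -> float:
        # Symbolic forward pass: bypass the component, record nothing.
        if self.surrogate is not None:
            return self.surrogate(x)
        # Neural forward pass: run the component and cache its I/O.
        y = self.component(x)
        self.inputs.append(x)
        self.outputs.append(y)
        return y

    def switch_to_symbolic(self, surrogate: Callable[[float], float]) -> None:
        self.surrogate = surrogate

    def switch_to_neural(self) -> None:
        self.surrogate = None


# Toy stand-in for a trained network component (assumption for the example).
component = lambda x: 2.0 * x + 1.0
wrapped = RecordingWrapper(component)

# Passing sample data through the wrapper collects the I/O pairs.
for sample in [0.0, 1.0, 2.0]:
    wrapped(sample)

# A hand-written equation stands in for one chosen from a regression search.
wrapped.switch_to_symbolic(lambda x: 2.0 * x + 1.0)
symbolic_output = wrapped(3.0)
```

The cached `inputs` and `outputs` lists are what a symbolic regression tool would be fitted on; keeping the original component inside the wrapper is what makes switching back to the neural forward pass cheap.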

Paper Structure (80 sections, 24 equations, 8 figures, 10 tables)


Figures (8)

  • Figure 1: A cartoon depicting how SymTorch is used to perform symbolic distillation on model components. For a trained PyTorch model, SymTorch wraps around any NN component in the model. The user passes in sample data, and during the forward pass the inputs and outputs (I/O) of the component are collected. Using PySR, SymTorch performs symbolic regression (SR) on the I/O to produce the best expressions approximating the behavior of the NN at different levels of complexity. Optionally, the user can select an equation from the Pareto front and replace the component with this chosen equation in the forward pass, producing a hybrid neural-symbolic model.
  • Figure 2: Approximating local model behavior with SLIME. For a complex non-linear model, we choose the point of interest $\mathbf{x}^*$. We sample points around this region and fit a symbolic model to these points.
  • Figure 3: Framework for reducing inference compute in transformer models by replacing MLP layers with symbolic surrogate models. Inputs to the MLP layer undergo dimensionality reduction via PCA. A symbolic model maps the reduced inputs to the outputs. The output activations are then expanded back to the model dimensionality, again through PCA.
  • Figure 4: Change in test set perplexity under PCA compression and reconstruction of MLP activations relative to the baseline perplexity of 10.62. For layers 7, 14, and 21, the MLP inputs are projected to a lower-dimensional subspace via PCA and then reconstructed prior to the MLP, while the MLP outputs are similarly projected and reconstructed before being passed to the remainder of the model.
  • Figure 5: Inference throughput (in tokens per second) versus perplexity on the test set for various language models. The test set is made up of a random chunk of 175k characters from the Wikitext-2-v1 dataset.
  • ...and 3 more figures
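
The local approximation described in the Figure 2 caption (choose a point of interest, sample around it, fit a simple model to the samples) can be sketched as follows. This is only an illustration under stated assumptions: a one-dimensional least-squares line is swapped in for the symbolic model, Gaussian sampling is assumed, and the function name `fit_local_linear` and its parameters are hypothetical rather than part of SymTorch or SLIME.

```python
import math
import random
from typing import Callable, Tuple


def fit_local_linear(
    f: Callable[[float], float],
    x_star: float,
    sigma: float = 0.1,
    n_samples: int = 500,
    seed: int = 0,
) -> Tuple[float, float]:
    """Fit y = slope * x + intercept to samples of f drawn around x_star.

    A straight line is a stand-in for the symbolic model; the sampling
    distribution (Gaussian with width sigma) is an assumption.
    """
    rng = random.Random(seed)
    xs = [rng.gauss(x_star, sigma) for _ in range(n_samples)]
    ys = [f(x) for x in xs]

    mean_x = sum(xs) / n_samples
    mean_y = sum(ys) / n_samples

    # Closed-form ordinary least squares for a single input variable.
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Example: near x* = 0, sin(x) behaves like the line y = x.
slope, intercept = fit_local_linear(math.sin, x_star=0.0)
```

The sampling width `sigma` controls how local the explanation is: a smaller value tracks the model's behavior near the point of interest more closely, at the cost of saying less about the surrounding region.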