Neural Learning of Fast Matrix Multiplication Algorithms: A StrassenNet Approach

Paolo Andreini, Alessandra Bernardi, Monica Bianchini, Barbara Toniella Corradini, Sara Marziali, Giacomo Nunziati, Franco Scarselli

TL;DR

StrassenNet is a neural architecture that reproduces the Strassen algorithm for $2\times 2$ multiplication. Across many independent runs the network always converges to a rank-$7$ tensor, thus numerically recovering Strassen's optimal algorithm.

Abstract

Fast matrix multiplication can be described as searching for low-rank decompositions of the matrix-multiplication tensor. We design a neural architecture, StrassenNet, which reproduces the Strassen algorithm for $2\times 2$ multiplication. Across many independent runs the network always converges to a rank-$7$ tensor, thus numerically recovering Strassen's optimal algorithm. We then train the same architecture on $3\times 3$ multiplication with rank $r\in\{19,\dots,23\}$. Our experiments reveal a clear numerical threshold: models with $r=23$ attain significantly lower validation error than those with $r\le 22$, suggesting that $r=23$ could actually be the smallest effective rank of the $3\times 3$ matrix-multiplication tensor. We also sketch an extension of the method to border-rank decompositions via an $\varepsilon$-parametrisation and report preliminary results consistent with the known bounds for the border rank of the $3\times 3$ matrix-multiplication tensor.
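
To make the "low-rank decomposition" view concrete, the following minimal sketch writes Strassen's classical $2\times 2$ algorithm as three factor matrices and checks that they reconstruct the matrix-multiplication tensor with $7$ rank-one terms. It is an illustration of the textbook algorithm only, not of the StrassenNet architecture or its training. The names `U`, `V`, `W`, `bilinear_multiply`, `matmul_tensor` and `reconstructed_tensor` are assumptions introduced here, as is the row-major flattening `[x11, x12, x21, x22]`.

```python
# Strassen's 2x2 algorithm as a rank-7 decomposition (illustrative sketch).
# Matrices are flattened row-major: [x11, x12, x21, x22].

# Row k of U (resp. V) gives the linear combination of A (resp. B) entries
# that enters the k-th of the 7 scalar products.
U = [
    [1, 0, 0, 1],    # A11 + A22
    [0, 0, 1, 1],    # A21 + A22
    [1, 0, 0, 0],    # A11
    [0, 0, 0, 1],    # A22
    [1, 1, 0, 0],    # A11 + A12
    [-1, 0, 1, 0],   # A21 - A11
    [0, 1, 0, -1],   # A12 - A22
]
V = [
    [1, 0, 0, 1],    # B11 + B22
    [1, 0, 0, 0],    # B11
    [0, 1, 0, -1],   # B12 - B22
    [-1, 0, 1, 0],   # B21 - B11
    [0, 0, 0, 1],    # B22
    [1, 1, 0, 0],    # B11 + B12
    [0, 0, 1, 1],    # B21 + B22
]
# Row l of W recombines the 7 products into output entry l of C = A B.
W = [
    [1, 0, 0, 1, -1, 0, 1],   # C11 = M1 + M4 - M5 + M7
    [0, 0, 1, 0, 1, 0, 0],    # C12 = M3 + M5
    [0, 1, 0, 1, 0, 0, 0],    # C21 = M2 + M4
    [1, -1, 1, 0, 0, 1, 0],   # C22 = M1 - M2 + M3 + M6
]


def bilinear_multiply(a, b):
    """Multiply two flattened 2x2 matrices using only 7 scalar products."""
    products = [
        sum(U[k][i] * a[i] for i in range(4)) * sum(V[k][j] * b[j] for j in range(4))
        for k in range(7)
    ]
    return [sum(W[l][k] * products[k] for k in range(7)) for l in range(4)]


def matmul_tensor(n=2):
    """T[i][j][l] = 1 iff a_i * b_j contributes to c_l in C = A B."""
    size = n * n
    T = [[[0] * size for _ in range(size)] for _ in range(size)]
    for p in range(n):
        for q in range(n):
            for s in range(n):
                T[n * p + q][n * q + s][n * p + s] = 1
    return T


def reconstructed_tensor():
    """Sum of the 7 rank-one terms U[k] (x) V[k] (x) W[:, k]."""
    return [
        [
            [sum(U[k][i] * V[k][j] * W[l][k] for k in range(7)) for l in range(4)]
            for j in range(4)
        ]
        for i in range(4)
    ]


if __name__ == "__main__":
    print(reconstructed_tensor() == matmul_tensor())           # True
    print(bilinear_multiply([1, 2, 3, 4], [5, 6, 7, 8]))       # [19, 22, 43, 50]
```

In this picture, searching for a faster algorithm amounts to searching for factor matrices with fewer rows in `U` and `V` (a smaller rank $r$) that still reconstruct the same tensor; the $3\times 3$ experiments described above vary exactly this $r$.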

Paper Structure (18 sections, 38 equations, 5 figures, 1 table, 1 algorithm)

Figures (5)

  • Figure 1: Tensor neural network which computes the row-by-column matrix product with the classical algorithm.
  • Figure 2: Tensor neural network which computes the row-by-column matrix product with the Strassen algorithm.
  • Figure 3: Final sub-network which computes the Strassen algorithm.
  • Figure 4: Plot of the training (a) and validation (b) losses of a batch of experiments for matrices with entries clamped in $[-1, 1]$ and values of rank $r$ between $19$ and $23$. In blue, orange, red, green and yellow the mean and standard deviation of the losses for $r = 23, 22, 21, 20, 19$, respectively.
  • Figure 5: Histogram showing the mean $\mu$ and standard deviation $\sigma$ of validation losses for matrices with entries limited to $[-1, 1]$ and rank values $r$ between $19$ and $23$. For $r = 23$, in blue, $\mu = 0.0025802$ and $\sigma = 0.0050602$; for $r = 22$, in orange, $\mu = 0.012773$ and $\sigma = 0.007641$; for $r = 21$, in red, $\mu = 0.029102$ and $\sigma = 0.0056912$; for $r = 20$, in green, $\mu = 0.034695$ and $\sigma = 0.0054031$; for $r = 19$, in yellow, $\mu = 0.047196$ and $\sigma = 0.0085363$.

Theorems & Definitions (10)

  • Definition 1
  • Definition 2
  • Definition 3
  • Definition 4
  • Definition 5
  • Definition 6
  • Definition 7
  • Definition 8
  • Definition 9
  • Definition 10