Machine learning with tree tensor networks, CP rank constraints, and tensor dropout

Hao Chen, Thomas Barthel

TL;DR

The work adapts tree tensor networks (TTN) to supervised image classification by imposing CP rank constraints on every tensor and applying tensor dropout for regularization. The resulting low-rank TTN substantially reduces parameter count and computation costs while remaining expressive, which makes large branching ratios $b$ affordable and improves representation power. On MNIST and Fashion-MNIST, the approach achieves competitive test accuracies ($98.3\%$ on MNIST and up to $90.3\%$ on Fashion-MNIST) and outperforms other tensor-network-based classifiers on Fashion-MNIST. Because they consist mostly of linear elements, such classifiers avoid the vanishing gradient problem of deep neural networks, and the CP rank constraints offer a tunable way to control overfitting and computation costs.
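To make the parameter savings concrete, here is a back-of-the-envelope sketch (not taken from the paper) that counts the entries of a single intermediate tree tensor $A:(\mathbb{C}^m)^{\otimes b}\to\mathbb{C}^m$. It assumes that a CP-rank-$r$ tensor is stored as $b+1$ factor matrices of shape $r\times m$, with any CP weights absorbed into the factors; the values of $m$ and $r$ below are illustrative, not the paper's settings.

```python
def full_tensor_params(m: int, b: int) -> int:
    """Entries of a dense tensor with b input legs and one output leg, all of dimension m."""
    return m ** (b + 1)


def cp_tensor_params(m: int, b: int, r: int) -> int:
    """Entries of a CP-rank-r tensor stored as b+1 factor matrices of shape (r, m).

    Assumption: the CP weights lambda_k are absorbed into the factor matrices.
    """
    return r * (b + 1) * m


b = 4  # branching ratio used in the figures below
for m, r in [(8, 4), (16, 8), (32, 16)]:  # illustrative (m, r) pairs only
    print(f"m={m} r={r} dense={full_tensor_params(m, b):,} cp={cp_tensor_params(m, b, r):,}")
```

The dense count grows exponentially in $b$, while the CP count grows only linearly in $b$, $m$, and $r$; this is why trees with branching ratios such as $b=4$ remain affordable once the CP rank is constrained.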

Abstract

Tensor networks developed in the context of condensed matter physics try to approximate order-$N$ tensors with a reduced number of degrees of freedom that is only polynomial in $N$ and arranged as a network of partially contracted smaller tensors. As we have recently demonstrated in the context of quantum many-body physics, computation costs can be further substantially reduced by imposing constraints on the canonical polyadic (CP) rank of the tensors in such networks [arXiv:2205.15296]. Here, we demonstrate how tree tensor networks (TTN) with CP rank constraints and tensor dropout can be used in machine learning. The approach is found to outperform other tensor-network-based methods in Fashion-MNIST image classification. A low-rank TTN classifier with branching ratio $b=4$ reaches a test set accuracy of 90.3% with low computation costs. Consisting of mostly linear elements, tensor network classifiers avoid the vanishing gradient problem of deep neural networks. The CP rank constraints have additional advantages: The number of parameters can be decreased and tuned more freely to control overfitting, improve generalization properties, and reduce computation costs. They allow us to employ trees with large branching ratios, substantially improving the representation power.
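The two ingredients named in the abstract can be sketched together in NumPy. A CP-rank-$r$ tensor $A=\sum_{k=1}^{r} a_k^{(0)}\otimes a_k^{(1)}\otimes\cdots\otimes a_k^{(b)}$ is stored only through its factor vectors, so applying it to $b$ child vectors costs $O(r(b+1)m)$ operations instead of the $O(m^{b+1})$ of a dense contraction. For tensor dropout, this sketch assumes that each rank-1 term is dropped independently with probability $p$ during training and that surviving terms are rescaled by $1/(1-p)$ (inverted dropout); the paper's exact scheme may differ. The function names are hypothetical, and real-valued factors are used for simplicity.

```python
import numpy as np


def cp_apply(factors, children, weights=None):
    """Apply a CP-rank-r tensor A: (R^m)^{x b} -> R^m to b child vectors.

    factors[0]: shape (r, m_out), output-leg vectors a_k^{(0)} as rows.
    factors[j]: shape (r, m_in), input-leg vectors a_k^{(j)} for j = 1..b.
    weights:    optional per-term multipliers of shape (r,), e.g. a dropout mask.
    """
    coeff = np.ones(factors[0].shape[0])
    for f, v in zip(factors[1:], children):
        coeff = coeff * (f @ v)  # overlaps <a_k^{(j)}, v_j> for all k at once: O(r m) per leg
    if weights is not None:
        coeff = coeff * weights  # zero out (and rescale) individual rank-1 terms
    return coeff @ factors[0]  # sum_k coeff_k a_k^{(0)}: O(r m)


def tensor_dropout_mask(r, p, rng):
    """Keep each of the r rank-1 terms with probability 1 - p, rescaled by 1/(1 - p)."""
    return (rng.random(r) >= p) / (1.0 - p)


rng = np.random.default_rng(0)
m, b, r = 6, 3, 4
factors = [rng.standard_normal((r, m)) for _ in range(b + 1)]
children = [rng.standard_normal(m) for _ in range(b)]

# Dense reference: A[o, i, j, l] = sum_k a0[k, o] a1[k, i] a2[k, j] a3[k, l], with m^(b+1) entries.
A = np.einsum("ko,ki,kj,kl->oijl", *factors)
dense_out = np.einsum("oijl,i,j,l->o", A, *children)
low_rank_out = cp_apply(factors, children)
print(np.allclose(dense_out, low_rank_out))  # True

# Training-time forward pass with tensor dropout (p = 0.25); at test time use weights=None.
mask = tensor_dropout_mask(r, p=0.25, rng=rng)
train_out = cp_apply(factors, children, weights=mask)
```

In this sketch, dropout removes whole rank-1 terms, so it acts directly on the CP rank of each tensor rather than on individual entries.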

Paper Structure

This paper contains 8 sections, 20 equations, 8 figures, and 1 table.

Figures (8)

  • Figure 1: A TTN classifier takes an image ${\boldsymbol{x}}\in\mathbb{R}^N$ as input, maps it to a feature vector ${\boldsymbol{\Phi}}({\boldsymbol{x}})$ in a much higher-dimensional space $(\mathbb{C}^{2})^{\otimes N}$, and then applies the weight tensor $W:(\mathbb{C}^{2})^{\otimes N}\to\mathbb{C}^L$, realized by a TTN, to obtain the decision function \eqref{eq:decisionFct}. (a) A TTN classifier with tree branching ratio $b=2$ for 1D images. (b) Tensors $A^{(\tau)}_i$ of intermediate layers are linear maps from $(\mathbb{C}^m)^{\otimes b}$ to $\mathbb{C}^m$. In low-rank TTN classifiers, we reduce computation costs and improve generalization properties by imposing the constraints \eqref{eq:CPD} on the CP ranks of the tensors. (c) 2D image TTN classifier with branching ratio $b=2$. (d) 2D image TTN classifier with $b=4$. See, for example, Refs. [Orus2014-349, Stoudenmire2016-29] for details on the graphical representation of tensor networks used here.
  • Figure 2: MNIST training histories of full TTN classifiers with branching ratio $b=4$ and a few different bond dimensions $m$. Solid and dashed lines represent the classification accuracy \eqref{eq:accuracy} for the training set and validation set, respectively.
  • Figure 3: MNIST training histories of low-rank TTN classifiers with branching ratio $b=4$, varying bond dimensions $m$, and CP ranks $r$. Solid and dashed lines represent the classification accuracy \eqref{eq:accuracy} for the training set and validation set, respectively. For comparison, the black lines show the full TTN results of Fig. 2.
  • Figure 4: Fashion-MNIST training histories of full TTN classifiers with branching ratio $b=4$ and a few different bond dimensions $m$. Solid and dashed lines represent the classification accuracy \eqref{eq:accuracy} for the training set and validation set, respectively.
  • Figure 5: Fashion-MNIST training histories of low-rank TTN classifiers with branching ratio $b=4$, varying bond dimensions $m$, and CP ranks $r$. Solid and dashed lines represent the classification accuracy \eqref{eq:accuracy} for the training set and validation set, respectively. For comparison, the black lines show the full TTN results of Fig. 4.
  • ...and 3 more figures
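The classifier of Fig. 1 can be sketched end to end for the 1D case with $b=2$ (panel a). Several details here are assumptions rather than the paper's specification: the local feature map $x\mapsto(\cos(\pi x/2),\sin(\pi x/2))$ (the common choice of Stoudenmire and Schwab), real instead of complex tensors, the random initialization, and predicting the label $\ell$ with the largest $|f_\ell(\boldsymbol{x})|$. Helper names such as build_ttn are hypothetical, and the number of pixels must be a power of $b$.

```python
import numpy as np


def feature_map(x):
    """Map pixel values in [0, 1] to local vectors in R^2 (assumed cos/sin encoding)."""
    return np.stack([np.cos(np.pi * x / 2), np.sin(np.pi * x / 2)], axis=-1)


def cp_apply(factors, children):
    """Apply a CP-rank-r tensor (factors[0]: output leg, factors[1:]: input legs) to child vectors."""
    coeff = np.ones(factors[0].shape[0])
    for f, v in zip(factors[1:], children):
        coeff = coeff * (f @ v)
    return coeff @ factors[0]


def build_ttn(rng, N, b, m, r, L):
    """Hypothetical random low-rank TTN: layers of CP tensors, the top one mapping to R^L."""
    layers, n, d = [], N, 2  # n nodes in the current layer, each carrying a d-dimensional vector
    while n > 1:
        assert n % b == 0, "N must be a power of b"
        n_out = n // b
        d_out = L if n_out == 1 else m  # bond dimension m inside, L labels at the root
        layer = [
            [rng.standard_normal((r, d_out))]
            + [rng.standard_normal((r, d)) / np.sqrt(d) for _ in range(b)]
            for _ in range(n_out)
        ]
        layers.append(layer)
        n, d = n_out, d_out
    return layers


def ttn_forward(layers, x):
    """Contract the tree bottom-up and return the decision values f_l(x), l = 0..L-1."""
    vecs = list(feature_map(x))
    for layer in layers:
        b = len(layer[0]) - 1  # each tensor has b input legs plus one output leg
        vecs = [cp_apply(t, vecs[i * b:(i + 1) * b]) for i, t in enumerate(layer)]
    return vecs[0]


def accuracy(layers, X, y):
    """Fraction of samples whose predicted label argmax_l |f_l(x)| matches y."""
    preds = np.array([np.argmax(np.abs(ttn_forward(layers, x))) for x in X])
    return float(np.mean(preds == y))


rng = np.random.default_rng(1)
layers = build_ttn(rng, N=8, b=2, m=4, r=3, L=10)
X = rng.random((5, 8))           # five toy "1D images" with 8 pixels each
y = rng.integers(0, 10, size=5)  # random labels, so the untrained accuracy is meaningless
print(ttn_forward(layers, X[0]).shape)  # (10,)
print(accuracy(layers, X, y))
```

Training would adjust the factor matrices by gradient descent on a loss built from the decision values $f_\ell$; that step, and any normalization between layers, is omitted from this sketch.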