ClassiFIM: An Unsupervised Method To Detect Phase Transitions

Victor Kasatkin, Evgeny Mozgunov, Nicholas Ezzell, Utkarsh Mishra, Itay Hen, Daniel Lidar

Abstract

Estimation of the Fisher Information Metric (FIM-estimation) is an important task that arises in unsupervised learning of phase transitions, a problem proposed by physicists. This work completes the definition of the task by defining rigorous evaluation metrics distMSE, distMSEPS, and distRE, and introduces ClassiFIM, a novel machine learning method designed to solve the FIM-estimation task. Unlike existing methods for unsupervised learning of phase transitions, ClassiFIM directly estimates a well-defined quantity (the FIM), allowing it to be rigorously compared to any present or future method that estimates the same quantity. ClassiFIM transforms a dataset for the FIM-estimation task into a dataset for an auxiliary binary classification task and involves selecting and training a model for the latter. We prove that the output of ClassiFIM approaches the exact FIM in the limit of infinite dataset size and under certain regularity conditions. We implement ClassiFIM on multiple datasets, including datasets describing classical and quantum phase transitions, and find that it approximates the ground truth well with modest computational resources. Furthermore, we independently implement two alternative state-of-the-art methods for unsupervised estimation of phase transition locations on the same datasets and find that ClassiFIM predicts such locations at least as well as these other methods. To emphasize the generality of our method, we also propose and generate the MNIST-CNN dataset, which consists of the output of CNNs trained on MNIST for different hyperparameter choices. Using ClassiFIM on this dataset suggests there is a phase transition in the distribution of image-prediction pairs for CNNs trained on MNIST, demonstrating the broad scope of FIM-estimation beyond physics.

Paper Structure (43 sections, 1 theorem, 26 equations, 5 figures, 3 tables, 2 algorithms)

Key Result

Theorem 1

Consider a statistical manifold-like pair $(\mathcal{M}, P)$ s.t. $\mathcal{M} \subset \mathbb{R}^m$ and the space $\Omega$ of all possible samples $x$ is finite. Let ${\boldsymbol{\lambda}}_0$ be a point in the interior of $\mathcal{M}$, and $M^*$ be a model, satisfying the following conditions: Then

Figures (5)

  • Figure 1: Illustration of the differences between the outputs of three unsupervised ML methods: W [vanNieuwenburg2016LearningPT], SPCA [Huang:22], and ClassiFIM, when applied to the same dataset for IsNNN400 SM. Panels (a)--(c) show 1D phase diagrams generated by W, SPCA, and ClassiFIM, respectively, for $\lambda_1 = 48/64$ [the slice indicated by the black line in panels (d)--(g)]. The vertical dotted lines indicate the maxima of the ground truth FIM shown in panels (c) and (g). Panels (d)--(f) show 2D phase diagrams generated by W, SPCA, and ClassiFIM, respectively. Panel (g) shows the ground truth FIM. Note that the scales on the diagrams (a)--(c), as well as the color schemes on the diagrams (d)--(g), are different, reflecting the different meanings of the outputs of the methods: the W method produces the accuracy of "mislabelled" samples, which is expected to achieve a peak with a value close to $1.0$ at the locations of phase transitions; SPCA produces the components of a kernel principal components analysis (PCA), which are expected to change rapidly at the locations of phase transitions; ClassiFIM produces an estimate of the Fisher Information Metric, reflecting the rate of change of the underlying probability distribution. For more details see \ref{as:prior-work} (Figures 4 and 5).
  • Figure 2: The neural network architecture we used for the Hubbard12 and FIL24 datasets. Here "Conv" layers are graph convolutional layers for the $12$-site lattice with two different types of edges.
  • Figure 3: ClassiFIM output (left) compared with properties of the outputs of the models trained on MNIST. From left to right, the remaining three panels show $1 - \mathrm{accuracy}$, $1 - \mathbb{E}\hat{p}_y$, where $\hat{p}_y$ is the predicted probability of the correct class, and the mean entropy of the predicted class probabilities. For these three plots, the averages are computed over all images in the training set, and over $14$ trained models for each ${\boldsymbol{\lambda}}$. The axes correspond to $\mathrm{max\_lr} = 10^{-3+2\lambda_0}$ and $\beta_1 = 1 - 10^{-4 \lambda_1}$. The white circle indicates the $\approx 99\%$ accuracy model from [tuomaso2021mnist].
  • Figure 4: Illustration of mod-W. Panel (a): Accuracy plot generated by the W method for horizontal slices, assembled into a 2D phase diagram (identical to Figure 1(d)). Accuracy values below $75\%$ are truncated to $75\%$ and represented in dark blue. Panel (b): The plot after post-processing, showing the removal of spurious low-accuracy points. Ground truth and predicted peaks are overlaid on the 2D diagram. Each horizontal slice contains $n_s'$ black circles representing the "inner" ground truth peaks, which are the targets for prediction by methods like mod-W. Additional $n_s - n_s'$ light squares in each slice represent ground truth peaks which are either too shallow or too close to the border, and their predictions are not necessary in the PeakRMSE metric given by \ref{eq:peakrmse}, but their count $n_s - n_s'$ represents the additional number of guesses each method like mod-W is allowed to make in that slice. The $n_s$ green crosses in each slice indicate these guesses made by mod-W. For each black circle in each slice, there is a corresponding contribution to the PeakRMSE metric equal to the distance to the nearest green cross in that slice (or the border of the slice if it is closer). Panel (c): Accuracy plot generated by method W for vertical slices within the same computational budget. In mod-W, we then post-process it and predict the peaks (not shown, the process is the same as for horizontal slices). Panel (d): A single horizontal slice at $\lambda_1 = 18/64$ (marked on panels (a) and (b) with a horizontal line) with W plot before and after post-processing, i.e., accuracy as a function of $\lambda_0$. Vertical dotted lines mark the ground truth peaks. For this dataset and slice, the predicted peaks (i.e., local maxima of the mod-W plot) align perfectly with the ground truth. On panels (a)--(c) the axes are $\lambda_0$ (horizontal) and $\lambda_1$ (vertical) ranging from $0$ to $1$.
  • Figure 5: Illustration of mod-ClassiFIM. Panel (a): The ClassiFIM prediction from Figure 1(f). Darker colors denote higher values of $\mathop{\mathrm{Tr}}\nolimits(\hat{g}({\boldsymbol{\lambda}}))$. To compute the intensity $I_c$ of color channels $c \in \{\mathrm{red}, \mathrm{green}, \mathrm{blue}\}$, we selected three unit vectors ${\boldsymbol{v}}_c$, each separated by $2\pi/3$, and computed $I_c = 1 - \min(1, \sqrt{\hat{g}({\boldsymbol{\lambda}}; {\boldsymbol{v}}_c)} / C)$ where $C = 140$ serves as a normalization constant and $\hat{g}({\boldsymbol{\lambda}}; {\boldsymbol{v}}) = \sum_{\mu\nu} \hat{g}_{\mu\nu}({\boldsymbol{\lambda}})v_{\mu} v_{\nu}$ is the squared length of vector ${\boldsymbol{v}}$ as per the metric $\hat{g}({\boldsymbol{\lambda}})$. This color scheme was also used in panels (f) and (g) of Figure 1. Panel (b): Semidisk illustrating the color scheme. For each point A in the semidisk, we define a vector ${\boldsymbol{v}} = \overrightarrow{OA}$ (an example of such a vector is shown). Then the color of such a point A represents $g$ defined using $g_{\mu\nu} = C^2 v_{\mu} v_{\nu}$. For instance, the color of point O is white because it corresponds to ${\boldsymbol{v}} = 0$ and, hence, $g=0$. The semidisk only shows the colors corresponding to rank-1 tensors $g$: for example, the color of $g = C^2 I$ would be black (not shown). Panel (c): $\hat{g}_{00}$, the metric component to be post-processed in the horizontal slices. Panel (d): Post-processed $\hat{g}_{00}$. Ground truth and predicted peaks are overlaid on the 2D diagram similarly to panel (b) of Figure 4. Panel (e): A single horizontal slice at $\lambda_1 = 18/64$ (indicated in panels (b) and (c) using a horizontal line) showing the raw and post-processed $\hat{g}_{00}$ values alongside the ground truth. Vertical dotted lines denote the positions of ground truth peaks. Nearby local maxima of the post-processed $\hat{g}_{00}$ are the predicted peak locations.
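
The color scheme described in the Figure 5 caption can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the authors' plotting code: the function names `channel_directions` and `fim_to_rgb` and the `offset` angle are hypothetical (the caption does not say how the three unit vectors are oriented), and the ratio $\sqrt{\hat{g}({\boldsymbol{\lambda}}; {\boldsymbol{v}}_c)} / C$ is capped at $1$ with a minimum so that intensities stay in $[0, 1]$, which is what makes $g = 0$ white and $g = C^2 I$ black as the caption states.

```python
import numpy as np

C = 140.0  # normalization constant quoted in the Figure 5 caption


def channel_directions(offset=0.0):
    """Three unit vectors separated by 2*pi/3, in the order (red, green, blue).

    The starting angle `offset` is an assumption: the caption does not
    specify the orientation of the vectors.
    """
    angles = offset + np.arange(3) * 2.0 * np.pi / 3.0
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)  # shape (3, 2)


def fim_to_rgb(g, directions=None, c=C):
    """Map 2x2 metric tensors g (shape (..., 2, 2)) to RGB intensities in [0, 1]."""
    g = np.asarray(g, dtype=float)
    if directions is None:
        directions = channel_directions()
    # Squared length of each v_c under g: sum_{mu,nu} g_{mu nu} v_mu v_nu.
    sq_len = np.einsum("cm,...mn,cn->...c", directions, g, directions)
    # Guard against tiny negative values caused by floating-point rounding.
    sq_len = np.clip(sq_len, 0.0, None)
    # I_c = 1 - min(1, sqrt(g(v_c)) / C): zero metric -> white, large metric -> black.
    return 1.0 - np.minimum(1.0, np.sqrt(sq_len) / c)
```

For a grid of estimates with shape `(n0, n1, 2, 2)`, `fim_to_rgb` returns an array of shape `(n0, n1, 3)` that can be passed directly to an image-display routine. A rank-1 tensor $g_{\mu\nu} = C^2 v_\mu v_\nu$ generally yields a tinted color, because its squared length differs across the three directions.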

Theorems & Definitions (2)

  • Theorem 1
  • proof: Proof of Theorem 1