
Deep Learning as Ricci Flow

Anthony Baptista, Alessandro Barp, Tapabrata Chakraborti, Chris Harbron, Ben D. MacArthur, Christopher R. S. Banerji

TL;DR

The paper investigates how deep neural networks transform data geometry during classification by defining a discrete Ricci-flow-inspired framework. It constructs layerwise $k$-NN graphs and computes Forman-Ricci curvature to quantify curvature-driven changes, introducing a global Ricci network flow metric (the Ricci coefficient) that captures the alignment between layerwise distance changes and curvature. Across synthetic and real datasets, most well-trained DNNs exhibit negative Ricci coefficients at appropriate scales, with stronger Ricci-flow-like dynamics correlating with higher accuracy, and the optimal scale $k$ depending on the data. The work suggests that differential-geometry tools can aid explainability and architecture design, enabling dataset-aware model selection and potential uncertainty estimation through geometric metrics.
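The curvature half of the pipeline described above can be sketched as follows. This is a minimal, dependency-free illustration, not the authors' code. It makes three assumptions. First, the $k$-NN graph is symmetrised, so an edge is added if either point is among the other's $k$ nearest neighbours. Second, curvature uses the triangle-augmented Forman-Ricci formula $4 - \deg(u) - \deg(v) + 3\,\#\text{triangles}(u,v)$. Third, a layer is summarised by the sum of its edge curvatures. The function names `knn_graph`, `forman_curvature` and `total_curvature` are hypothetical.

```python
import math


def knn_graph(points, k):
    """Build an undirected k-NN graph as a list of neighbour sets.

    ASSUMPTION: nodes i and j are joined if j is among the k nearest
    neighbours of i (Euclidean distance) or vice versa; ties are broken
    by node index.
    """
    n = len(points)
    adj = [set() for _ in range(n)]
    for i in range(n):
        ranked = sorted((math.dist(points[i], points[j]), j) for j in range(n) if j != i)
        for _, j in ranked[:k]:
            adj[i].add(j)
            adj[j].add(i)
    return adj


def forman_curvature(adj):
    """Forman-Ricci curvature of every edge (u, v) with u < v.

    ASSUMPTION: uses the triangle-augmented form
    F(u, v) = 4 - deg(u) - deg(v) + 3 * #triangles(u, v);
    the paper's exact variant may differ.
    """
    curvature = {}
    for u, neighbours in enumerate(adj):
        for v in neighbours:
            if u < v:
                triangles = len(adj[u] & adj[v])  # each common neighbour closes a triangle
                curvature[(u, v)] = 4 - len(adj[u]) - len(adj[v]) + 3 * triangles
    return curvature


def total_curvature(adj):
    """Sum of edge curvatures: one scalar summary of a layer's k-NN graph."""
    return sum(forman_curvature(adj).values())


# Toy check of the intuition in Figure 1C: clique-like structure is
# positively curved, tree-like structure is negatively curved.
clique = [{1, 2, 3}, {0, 2, 3}, {0, 1, 3}, {0, 1, 2}]
star = [{1, 2, 3, 4}, {0}, {0}, {0}, {0}]
print(forman_curvature(clique)[(0, 1)])  # 4
print(forman_curvature(star)[(0, 1)])    # -1
```

In practice each layer's output for the test set would be passed to `knn_graph`, giving one total curvature per layer. The triangle term is what lets dense, clique-like neighbourhoods come out positive. Without it, every edge in a dense graph would be strongly negative.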

Abstract

Deep neural networks (DNNs) are powerful tools for approximating the distribution of complex data. It is known that data passing through a trained DNN classifier undergoes a series of geometric and topological simplifications. Some progress has been made toward understanding these transformations in neural networks with smooth activation functions. However, the more general setting of non-smooth activation functions, such as the rectified linear unit (ReLU), which tend to perform better, remains to be understood. Here we propose that the geometric transformations performed by DNNs during classification tasks have parallels to those expected under Hamilton's Ricci flow, a tool from differential geometry that evolves a manifold by smoothing its curvature in order to identify its topology. To illustrate this idea, we present a computational framework to quantify the geometric changes that occur as data passes through successive layers of a DNN. We use this framework to motivate a notion of 'global Ricci network flow' that can be used to assess a DNN's ability to disentangle complex data geometries to solve classification problems. By training more than $1,500$ DNN classifiers of different widths and depths on synthetic and real-world data, we show that the strength of global Ricci network flow-like behaviour correlates with accuracy for well-trained DNNs, independently of depth, width and data set. Our findings motivate applying tools from differential and discrete geometry to the problem of explainability in deep learning.
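For reference, Hamilton's Ricci flow evolves a Riemannian metric $g$ in time $t$ according to the first equation below, where $R_{ij}$ is the Ricci curvature tensor. The second line is only an illustrative heuristic and not the paper's definition. It treats the layer index $l$ as time, $d_l(i,j)$ as a distance between points $i$ and $j$ at layer $l$, and $\kappa_l(i,j)$ as a discrete curvature, with $c$ a hypothetical positive step constant.

```latex
% Hamilton's Ricci flow: the metric evolves against the Ricci curvature
\frac{\partial g_{ij}}{\partial t} = -2\,R_{ij}

% Heuristic discrete analogue (assumption): positively curved pairs contract,
% negatively curved pairs expand from one layer to the next
d_{l+1}(i,j) - d_l(i,j) \approx -\,c\,\kappa_l(i,j)\,d_l(i,j), \qquad c > 0
```

Under this reading, curvature at layer $l$ and the change in distance from layer $l$ to layer $l+1$ should be negatively correlated. This is consistent with negative Ricci coefficients being the signature of Ricci-flow-like behaviour.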

Paper Structure (12 sections, 12 equations, 4 figures)

Figures (4)

  • Figure 1: Deep learning and Ricci flow. A. An example of deep learning. The structure of two non-linearly separable, entwined manifolds is learned by a deep neural network (DNN). A test set, consisting of random samples drawn from the two manifolds, is passed through the trained DNN, and the output of each layer is visualised via its first two principal components. As the test set passes through the layers of the trained DNN, irregularities in the geometry of the data are smoothed and the two manifolds are separated. B. An example of Ricci flow. An irregular manifold, consisting of two generally positively curved regions joined by a region of negative curvature, evolves under Ricci flow. The irregularities on the positively curved regions are smoothed and the negatively curved region expands, separating them. C. When represented as a $k$-nearest neighbour graph, dense sets of points form positively curved, clique-like structures, which are drawn together under discrete Ricci flow; sparse sets of points form negatively curved, tree-like structures, which are separated by Ricci flow.
  • Figure 2: Data sets for binary classification and DNN architectures trained. A. Three synthetic data sets, A, B and C, describe binary classification problems with different degrees of geometric and topological entanglement. We also considered two binary classification problems from the MNIST data set: distinguishing similar-looking digits ('1' vs '7' and '6' vs '8'). Finally, we considered two binary classification problems from the fashion MNIST (fMNIST) data set: distinguishing similar-looking items of clothing ('sandals' vs 'ankle boots' and 'shirts' vs 'coats'). B. For each problem, three different DNN widths were considered: narrow (25 nodes wide), wide (50 nodes wide) and bottleneck, as shown. For each choice of width, two depths were trained: shallow (5 hidden layers) and deep (11 hidden layers).
  • Figure 3: Ricci flow-like behaviour and the number of nearest neighbours $k$. A. Heatmap of aggregated Ricci coefficients, computed across $\geq 25$ DNNs of a given width and depth trained on a given data set, for various values of $k$, evaluated on synthetic test data sets A, B and C of 1,000 points. B. Heatmap of aggregated Ricci coefficients, computed across $\geq 25$ DNNs of a given width and depth trained on a given data set, for various values of $k$, evaluated on binary comparisons in the MNIST and fMNIST data sets with test sets of $\sim 2,000$ points. Black boxes outline the value of $k$ yielding the most negative aggregated Ricci coefficient. In both heatmaps, the dendrogram shows the results of hierarchical clustering using the aggregated Ricci coefficients as features. Higher values of $k$ are required to observe Ricci flow-like behaviour in the MNIST data sets than in the synthetic data sets.
  • Figure 4: Ricci flow-like behaviour has different implications for different data sets. Scatter plots display total curvature at layer $l$ against the total change in distance between point pairs from layer $l$ to layer $l+1$, for each data set and various DNN architectures. The Ricci coefficient ($\rho$) for each DNN is shown on each plot. For the synthetic data sets, the total distance change between points increases through the layers and curvature drops, implying a separation of points from different classes. Conversely, for MNIST and fMNIST the total distance change decreases through the layers and curvature increases, implying an aggregation of points from the same class.
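The quantities plotted in Figures 3 and 4 can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation. The sketch assumes the following:

- $\rho$ is the Pearson correlation, across the layers of one network, between total curvature at layer $l$ and total distance change from layer $l$ to layer $l+1$.
- Per-layer distances are given as dicts keyed by point pair, for example shortest-path lengths in each layer's $k$-NN graph.
- "Aggregated" in Figure 3 means the mean over networks.

All function names are hypothetical, and the toy numbers are made up for illustration.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (needs non-zero variance)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


def total_distance_change(dist_l, dist_next):
    """Sum over point pairs of d_{l+1}(i, j) - d_l(i, j).

    ASSUMPTION: pairs missing from either layer (e.g. disconnected in a
    k-NN graph) are skipped.
    """
    shared = dist_l.keys() & dist_next.keys()
    return sum(dist_next[p] - dist_l[p] for p in shared)


def ricci_coefficient(curvatures, distances):
    """Correlate total curvature at layer l with the distance change from l to l+1.

    curvatures: total curvature per layer (one value per layer)
    distances:  pairwise-distance dict per layer (same length as curvatures)
    """
    xs = curvatures[:-1]  # the last layer has no successor
    ys = [total_distance_change(distances[l], distances[l + 1])
          for l in range(len(distances) - 1)]
    return pearson(xs, ys)


def best_k(coefficients_by_k):
    """Return the k whose mean Ricci coefficient across networks is most negative."""
    return min(coefficients_by_k,
               key=lambda k: sum(coefficients_by_k[k]) / len(coefficients_by_k[k]))


# Made-up toy network: curvature falls while distances grow -> rho close to -1.
curv = [3.0, 2.0, 1.0, 0.0, 5.0]
dists = [{(0, 1): d} for d in (1.0, 2.0, 4.0, 7.0, 11.0)]
print(round(ricci_coefficient(curv, dists), 6))  # -1.0
print(best_k({5: [-0.2, -0.4], 10: [-0.9, -0.5], 20: [0.1, -0.3]}))  # 10
```

A negative $\rho$ means layers with high total curvature are followed by contraction and layers with low curvature by expansion, as the Ricci-flow analogy predicts. Choosing $k$ by the most negative mean mirrors the black boxes in Figure 3, but other aggregations (e.g. the median) would be equally plausible.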