
QuATON: Quantization Aware Training of Optical Neurons

Hasindu Kariyawasam, Ramith Hettiarachchi, Quansan Yang, Alex Matlock, Takahiro Nambara, Hiroyuki Kusaka, Yuichiro Kunai, Peter T C So, Edward S Boyden, Dushan Wadduwage

TL;DR

This work proposes a physics-informed quantization-aware training framework that accounts for physical constraints during training, leading to robust designs of state-of-the-art optical processors that use diffractive networks for multiple physics-based tasks despite quantized learnable parameters.

Abstract

Optical processors, built with "optical neurons", can efficiently perform high-dimensional linear operations at the speed of light. Thus, they are a promising avenue to accelerate large-scale linear computations. With current advances in micro-fabrication, such optical processors can now be 3D fabricated, but only with limited precision. This limitation translates to quantization of the learnable parameters in optical neurons, and it should be handled during the design of the optical processor to avoid a model mismatch between the trained design and the fabricated device. Specifically, optical neurons should be trained or designed within the physical constraints at a predefined quantized precision level. To address this critical issue, we propose a physics-informed quantization-aware training framework. Our approach accounts for physical constraints during the training process, leading to robust designs. We demonstrate that our approach can design state-of-the-art optical processors using diffractive networks for multiple physics-based tasks despite quantized learnable parameters. We thus lay the foundation upon which improved optical processors may be 3D fabricated in the future.
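To make the quantization constraint concrete, the sketch below snaps a full-precision phase weight to the nearest of n equally spaced levels in [0, 2π), which is roughly what post-quantization (PQ) does after full-precision training, and measures how far the resulting complex transmission exp(iφ) moves away from the trained one. This is a minimal illustration under stated assumptions: phase-only neurons, uniform level spacing, and the hypothetical helper names `quantize_phase` and `modulation_error` are not taken from the paper.

```python
import cmath
import math

TWO_PI = 2.0 * math.pi


def quantize_phase(phi: float, n_levels: int) -> float:
    """Snap a phase (radians) to the nearest of n_levels uniform levels in [0, 2*pi).

    Assumption (illustrative only): the levels are k * 2*pi / n_levels, k = 0..n_levels-1.
    """
    step = TWO_PI / n_levels
    wrapped = phi % TWO_PI                    # map any real phase into [0, 2*pi)
    k = round(wrapped / step) % n_levels      # nearest level index; 2*pi wraps to level 0
    return k * step


def modulation_error(phi: float, n_levels: int) -> float:
    """|exp(i*phi) - exp(i*Q(phi))|: deviation of the quantized neuron's complex
    transmission from the full-precision one it was trained as."""
    return abs(cmath.exp(1j * phi) - cmath.exp(1j * quantize_phase(phi, n_levels)))


if __name__ == "__main__":
    phi = 1.0  # a hypothetical full-precision phase weight (radians)
    for n in (2, 4, 8, 16):
        print(n, round(quantize_phase(phi, n), 4), round(modulation_error(phi, n), 4))
```

Since every phase lies within half a step (π/n) of some level, the worst-case deviation is 2·sin(π/(2n)). It is small only when n is large, which is why training directly under the quantized constraint matters at 2, 4, or 8 levels.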

Paper Structure

This paper contains 4 sections, 22 equations, 7 figures, and 2 tables.

Figures (7)

  • Figure 1: Representative results for the performance of QuATON (PSQ-LT) compared with other quantization-aware training (QAT) methods (PQ, STE, GS) used to train optical neurons. Here PSQ-LT, PQ, STE, and GS stand for progressive sigmoid quantization with learnable temperature, post-quantization, straight-through estimator, and Gumbel-softmax, respectively.
  • Figure 2: Quantization-aware training of diffractive deep neural networks (D2NNs) using progressive sigmoid quantization (PSQ): A) D2NN architecture: D2NNs consist of several diffractive layers. The input field passes through the layers, and the detector captures the output intensity. B1) An example of the amplitude and phase of the input field for the MNIST dataset. The input phase contains the information of interest. B2) The ground-truth output intensities for the two tasks considered. For the all-optical classification task, the detector region is divided into 10 patches, one per class, outlined by red dotted lines. In the example shown, the patch corresponding to digit 3 is lit up, and the other patches have zero intensity. For the all-optical quantitative phase imaging (QPI) task, the ground-truth output intensity is proportional to the input phase. C) The training procedure, which optimizes only the phase coefficients of the D2NN. During forward propagation, the raw phase weights of the $n^{\textrm{th}}$ layer $(\varphi_{n}[x,y])$ are passed through the PSQ function $Q_s(\cdot)$. The immediate output of the layer $(E^{n}_{out})$ is obtained by modulating the layer input $(E^{n}_{in})$ with the soft-quantized phase coefficients, as shown in the red box. During backward propagation through the layer, the partial derivatives with respect to the phase weights are computed as shown in the green box. D) The evolution of the PSQ function with the temperature parameter $(\tau)$. As $\tau$ increases from 1 to 20, the function gradually approaches hard quantization while remaining differentiable.
  • Figure 3: All-optical classification results: A1-A2) Two examples of the phase of the incoming wave to the D2NN, one for each of the two datasets considered. A3-A4) Output intensities for the D2NNs trained with full-precision (FP) weights. B-G) Classification results for the quantization-aware trained D2NNs for each example. Each row, labeled x-Qn, shows the results for dataset x$\in \{\textrm{MNIST, CIFAR10}\}$ using D2NNs trained with n-level quantized weights ($n \in \{2, 4, 8\}$). Each column corresponds to a different QAT method, as labeled above row B). H) Comparison of quantitative results (classification accuracy over the test set of each dataset). I-L) Confusion matrices for the FP model and the best-performing method at each quantization level for the MNIST dataset.
  • Figure 4: All-optical quantitative phase imaging results: A1-A3) Three examples of the phase of the incoming wave to the D2NN, one for each of the three datasets considered. A4-A6) Output intensities $\times \pi$ for the D2NNs trained with full-precision (FP) weights. B-J) QPI results for the quantization-aware trained D2NNs for each example. Each row, labeled x-Qn, shows the results for dataset x$\in \{\textrm{RBC, TINYIM, MNIST}\}$ using D2NNs trained with n-level quantized weights ($n \in \{4, 8, 16\}$). Each column corresponds to a different QAT method, as labeled above row B). Note that all results are given as $\textrm{output intensity} \times \pi$. K) Comparison of the quantitative results (mean SSIM over the test set of each dataset). L1-L3) Variation of the mean absolute phase error against the ground-truth phase for the RBC-Q4, TINYIM-Q8, and MNIST-Q16 cases, respectively. These plots are computed for the examples shown in the figure.
  • Figure S1: Mode collapse of D2NNs trained using GS for the RBC dataset: A1-A8) Eight randomly selected examples from the RBC test set. Subsequent rows show the resulting output intensities from D2NNs trained using GS with 4-level (B1-B8), 8-level (C1-C8), and 16-level (D1-D8) quantized phase weights. Although the inputs have different morphologies, each D2NN produces similar outputs for all of them.
  • ...and 2 more figures
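The Figure 2 caption describes PSQ only qualitatively: raw phase weights pass through a soft quantizer Q_s(·) whose temperature τ sharpens it toward hard quantization while keeping it differentiable, and each layer modulates its input field with the soft-quantized phase. The sketch below is one plausible construction under stated assumptions, not the paper's exact Q_s: a sum-of-sigmoids staircase with n uniform levels in [0, 2π), with τ measured in units of the level spacing. The names `psq`, `psq_grads`, and `modulate` are hypothetical.

```python
import cmath
import math

TWO_PI = 2.0 * math.pi


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def psq(phi: float, n_levels: int, tau: float) -> float:
    """Soft n-level phase quantizer built as a staircase of n sigmoids.

    Assumed form: Q_s(phi) = delta * sum_{k=1..n} sigmoid(tau * (phi/delta - k + 1/2)),
    with delta = 2*pi/n and phi wrapped into [0, 2*pi). As tau grows, each sigmoid
    approaches a unit step at a level midpoint, so Q_s approaches rounding phi to the
    nearest multiple of delta (the top step lands on 2*pi, which equals 0 in phase).
    """
    delta = TWO_PI / n_levels
    u = (phi % TWO_PI) / delta
    return delta * sum(_sigmoid(tau * (u - k + 0.5)) for k in range(1, n_levels + 1))


def psq_grads(phi: float, n_levels: int, tau: float) -> tuple[float, float]:
    """Analytic (dQ_s/dphi, dQ_s/dtau), usable in a backward pass or for a learnable tau."""
    delta = TWO_PI / n_levels
    u = (phi % TWO_PI) / delta
    d_phi = 0.0
    d_tau = 0.0
    for k in range(1, n_levels + 1):
        z = u - k + 0.5
        s = _sigmoid(tau * z)
        d_phi += tau * s * (1.0 - s)          # delta * sigma' * tau / delta
        d_tau += delta * s * (1.0 - s) * z    # delta * sigma' * z
    return d_phi, d_tau


def modulate(e_in: complex, phi: float, n_levels: int, tau: float) -> complex:
    """Phase-only diffractive neuron: E_out = E_in * exp(i * Q_s(phi))."""
    return e_in * cmath.exp(1j * psq(phi, n_levels, tau))


if __name__ == "__main__":
    for tau in (1.0, 20.0, 200.0):
        print(tau, round(psq(1.2, 4, tau), 4))  # approaches pi/2 for large tau
```

Raising τ on a schedule gives the "progressive" behavior the caption describes, and `psq_grads` also returns ∂Q_s/∂τ, which is what one would need if τ were learnable as in PSQ-LT. At small τ the soft output does not necessarily move monotonically toward the hard level at every φ, because the tails of distant sigmoids add up. For fabrication, one would presumably replace Q_s with its hard limit so the phases lie exactly on the n levels.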