OptINC: Optical In-Network-Computing for Scalable Distributed Learning

Sijie Fei, Grace Li Zhang, Bing Li, Ulf Schlichtmann

Abstract

Distributed learning is widely used for training large models on large datasets by distributing parts of the model or dataset across multiple devices and aggregating the computed results for subsequent computations or parameter updates. Existing communication algorithms for distributed learning, such as ring all-reduce, result in heavy communication overhead between servers. Since communication in large-scale systems already takes place over optical fibers, we propose an Optical In-Network-Computing (OptINC) architecture that offloads computation from the servers onto the optical interconnects. To execute gradient averaging and quantization in the optical domain, we incorporate optical devices such as Mach-Zehnder interferometers (MZIs) into the interconnects. Such a de facto optical neural network (ONN) can effectively reduce the communication overhead of existing distributed training solutions. To reduce the dataset complexity for training this neural network, a preprocessing algorithm implemented in the optical domain is also proposed. Hardware cost is lowered by approximating the weight matrices of the optical neural network with unitary and diagonal matrices, while accuracy is maintained by a proposed hardware-aware training algorithm. The proposed solution was evaluated on real distributed learning tasks, including ResNet50 on CIFAR-100 and a LLaMA-based network on Wikipedia-1B. In both cases, the proposed framework achieves training accuracy comparable to the ring all-reduce baseline while eliminating the communication overhead.
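
For context, the baseline the abstract compares against can be summarized with a short simulation of ring all-reduce gradient averaging: a reduce-scatter phase in which partial sums of gradient chunks circulate around a logical ring, followed by an all-gather phase that distributes the completed sums. The following is a minimal NumPy sketch under the assumption of in-memory "servers" exchanging chunks with their ring neighbours; the function name ring_allreduce_average and the toy usage at the bottom are illustrative and not taken from the paper.

```python
import numpy as np

def ring_allreduce_average(grads):
    """Average per-server gradient vectors with the ring all-reduce schedule:
    a reduce-scatter phase followed by an all-gather phase. Each 'server' is
    just an in-memory array here; in a real cluster the forwarded chunks are
    the inter-server traffic that OptINC moves into the optical interconnect."""
    n = len(grads)
    # Each server splits its gradient into n chunks (one per ring position).
    chunks = [np.array_split(g.astype(np.float64), n) for g in grads]

    # Reduce-scatter: in step s, server i forwards chunk (i - s) mod n to its
    # ring neighbour, which adds it to its own copy. After n - 1 steps,
    # server i holds the complete sum of chunk (i + 1) mod n.
    for s in range(n - 1):
        for i in range(n):
            c = (i - s) % n
            j = (i + 1) % n
            chunks[j][c] = chunks[j][c] + chunks[i][c]

    # All-gather: each server circulates its completed chunk around the ring,
    # so after another n - 1 steps every server holds every completed chunk.
    for s in range(n - 1):
        for i in range(n):
            c = (i + 1 - s) % n
            j = (i + 1) % n
            chunks[j][c] = chunks[i][c].copy()

    # Every server now reconstructs the same averaged gradient.
    return [np.concatenate(ch) / n for ch in chunks]

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    grads = [rng.normal(size=10) for _ in range(4)]
    averaged = ring_allreduce_average(grads)
    assert np.allclose(averaged[0], np.mean(grads, axis=0))
```

Note that every gradient element traverses the ring twice, which is the per-server communication volume that OptINC aims to remove by performing the averaging inside the optical interconnect.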

Paper Structure

This paper contains 10 sections, 10 equations, 5 figures, and 1 table.

Figures (5)

  • Figure 1: The ring all-reduce algorithm in distributed training with four servers connected to a switch, forming a logical ring topology.
  • Figure 2: The interleaving MZI array for a 4$\times$4 unitary matrix (adapted from [onn5]), where an MZI consists of two directional couplers (DCs) and two phase shifters (PSs); see the parameterization sketch after this list.
  • Figure 3: The proposed OptINC architecture connecting $N$ servers, $S_1$ to $S_N$, each with $M$ full-duplex optical transceivers. The system consists of three components: a preprocessing unit $\mathbf{P}$, an ONN $\boldsymbol{f}_\theta$, and a splitting unit $\mathbf{T}$.
  • Figure 4: Weight matrix $W$ can be partitioned into square submatrices $W_s$ in two ways, horizontally or vertically.
  • Figure 5: The cascading topology with OptINCs in two levels, supporting up to $N^2$ servers.
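
The interleaved MZI array of Figure 2 is what makes the optical matrix multiplication programmable: each MZI acts as a tunable 2$\times$2 unitary on two adjacent waveguides, alternating layers of such blocks compose into a full $N \times N$ unitary, and a trailing column of phase shifters supplies a diagonal phase stage. Below is a minimal NumPy sketch of this composition for $N = 4$; the 2$\times$2 MZI convention, the layer count, and the names mzi_block and mesh_unitary are assumptions for illustration rather than the paper's exact construction.

```python
import numpy as np

def mzi_block(theta, phi):
    """A 2x2 unitary realized by one MZI (two couplers plus two phase
    shifters); this parameterization convention is assumed, not the paper's."""
    return np.array([[np.exp(1j * phi) * np.cos(theta), -np.sin(theta)],
                     [np.exp(1j * phi) * np.sin(theta),  np.cos(theta)]])

def mesh_unitary(thetas, phis, out_phases, n=4):
    """Compose an n x n unitary from n alternating ('interleaved') layers of
    MZIs on adjacent waveguide pairs, followed by output phase shifters."""
    U = np.eye(n, dtype=complex)
    k = 0
    for layer in range(n):
        start = layer % 2                      # even layers: pairs (0,1),(2,3); odd: (1,2),...
        for i in range(start, n - 1, 2):
            L = np.eye(n, dtype=complex)
            L[i:i + 2, i:i + 2] = mzi_block(thetas[k], phis[k])
            U = L @ U
            k += 1
    # A trailing column of phase shifters acts as a diagonal unitary.
    return np.diag(np.exp(1j * np.asarray(out_phases))) @ U

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, m = 4, 6                                # n*(n-1)/2 = 6 MZIs for 4 modes
    U = mesh_unitary(rng.uniform(0, np.pi, m), rng.uniform(0, 2 * np.pi, m),
                     rng.uniform(0, 2 * np.pi, n), n)
    assert np.allclose(U.conj().T @ U, np.eye(n))   # unitary by construction
```

One common reading of the abstract's unitary-plus-diagonal weight approximation is an SVD-style factorization $W \approx U\,\mathrm{diag}(s)\,V^\dagger$, where two such meshes realize $U$ and $V^\dagger$ and an attenuation/phase stage realizes the diagonal; this is a standard ONN construction and is offered here as an interpretation consistent with, but not necessarily identical to, the paper's scheme.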