AXLearn: Modular, Hardware-Agnostic Large Model Training

Mark Lee, Chang Lan, Tom Gunter, John Peebles, Hanzhi Zhou, Kelvin Zou, Sneha Bangalore, Chung-Cheng Chiu, Nan Du, Xianzhi Du, Philipp Dufter, Ruixuan Hou, Haoshuo Huang, Dongseong Hwang, Xiang Kong, Jinhao Lei, Tao Lei, Meng Li, Li Li, Jiarui Lu, Zhiyun Lu, Yiping Ma, David Qiu, Vivek Rathod, Senyu Tong, Zhucheng Tu, Jianyu Wang, Yongqiang Wang, Zirui Wang, Floris Weers, Sam Wiseman, Guoli Yin, Bowen Zhang, Xiyou Zhou, Danyang Zhuo, Cheng Leong, Ruoming Pang

TL;DR

This paper presents AXLearn, a production system for scalable, high-performance training of large deep learning models that emphasizes modularity and hardware-agnostic training, and shares the authors' experience developing and operating it at Apple.

Abstract

AXLearn is a production system which facilitates scalable and high-performance training of large deep learning models. Compared to other state-of-the-art deep learning systems, AXLearn has a unique focus on modularity and support for hardware-agnostic training. AXLearn's internal interfaces between software components follow strict encapsulation, allowing different components to be assembled to facilitate rapid model development and experimentation on different hardware infrastructure. AXLearn maintains constant complexity as we scale the components in the system, compared to linear or quadratic complexity in state-of-the-art training systems. This allows integrating features such as Rotary Position Embeddings (RoPE) into AXLearn across hundreds of modules with just 10 lines of code, compared to the hundreds of lines required in other systems. At the same time, AXLearn maintains performance equivalent to that of state-of-the-art training systems. Finally, we share our experience in the development and operation of AXLearn at Apple.

Paper Structure

This paper contains 20 sections, 5 figures, 4 tables.

Figures (5)

  • Figure 1: Specifying an MoE transformer in AXLearn. Red components are reused from the specification of a standard transformer. In AXLearn, a user script that defines MoE only needs to specify the green parts of the neural network.
  • Figure 2: AXLearn's system diagram. The blue components belong to AXLearn.
  • Figure 3: Invocation Context. Module invocations push contexts to the stack, which retrieve child states, split PRNG keys, and create child output collections. Upon returning, contexts are popped, collecting outputs into the parent collection. The context stack can be programmatically traversed to retrieve shared state, allowing features like tied weights to preserve encapsulation.
  • Figure 4: AXLearn's training throughput and MFU when scaling the number of TPUs.
  • Figure 5: AXLearn's recovery latency under a hardware failure.