JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning

Anique Tahir, Lu Cheng, Huan Liu

TL;DR

We introduce a novel framework for fine-tuning Llama-2 models with parameter-efficient fine-tuning (PEFT) methods. It combines distributed training with JAX's just-in-time (JIT) compilation and tensor sharding to manage resources efficiently, accelerating fine-tuning while reducing memory requirements.

Abstract

Scaling Large Language Models (LLMs) for retrieval-based tasks, particularly Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning on long prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs, but they do not provide the efficient parameter distribution needed to accommodate retrieved context. To address this gap, we introduce a novel framework for fine-tuning Llama-2 models with parameter-efficient fine-tuning (PEFT) methods using distributed training. Our framework uniquely uses JAX's just-in-time (JIT) compilation and tensor sharding for efficient resource management, accelerating fine-tuning while reducing memory requirements. This significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show a more than 12x runtime improvement over a Hugging Face/DeepSpeed implementation with four GPUs, while consuming less than half the VRAM per GPU.
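
The abstract combines three ingredients: LoRA-style parameter-efficient fine-tuning, JAX's `jax.jit`, and tensor sharding. The sketch below shows one way these can fit together for a single linear layer, assuming a column-parallel layout in which the frozen base weight `W` and the LoRA factor `B` are split along the output dimension across a one-dimensional device mesh. The names (`lora_linear`, `loss_fn`, the `"model"` mesh axis), the shapes, and the sharding layout are hypothetical illustrations, not JORA's API or its actual partitioning scheme.

```python
# Hypothetical sketch: a LoRA-adapted linear layer with column-parallel sharding
# and a jit-compiled forward pass. Not JORA's actual API or sharding layout.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n_dev = len(devices)
mesh = Mesh(np.array(devices), axis_names=("model",))  # 1-D mesh over all devices

# Illustrative sizes; d_out is a multiple of n_dev so it splits evenly.
d_in, d_out, rank, alpha = 64, 32 * n_dev, 8, 16.0

key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(key_w, (d_in, d_out)) * 0.02       # frozen base weight
lora_a = jax.random.normal(key_a, (d_in, rank)) * 0.01   # trainable factor A
lora_b = jnp.zeros((rank, d_out))                        # trainable factor B, zero-init
x = jax.random.normal(key_x, (4, d_in))                  # a small batch of activations

# Column-parallel: shard the output dimension of W and B; replicate A and x.
col = NamedSharding(mesh, P(None, "model"))
rep = NamedSharding(mesh, P())
w = jax.device_put(w, col)
lora_b = jax.device_put(lora_b, col)
lora_a = jax.device_put(lora_a, rep)
x = jax.device_put(x, rep)

@jax.jit
def lora_linear(x, w, lora_a, lora_b):
    # y = x W + (alpha / r) * (x A) B
    return x @ w + (alpha / rank) * (x @ lora_a) @ lora_b

def loss_fn(lora_params, x, w):
    # Toy loss; only the LoRA factors are differentiated, W stays frozen.
    a, b = lora_params
    return jnp.mean(lora_linear(x, w, a, b) ** 2)

grads = jax.jit(jax.grad(loss_fn))((lora_a, lora_b), x, w)
y = lora_linear(x, w, lora_a, lora_b)
print(y.shape)
```

Because `B` starts at zero, the adapted layer initially reproduces the frozen layer exactly, and gradients and optimizer state are needed only for the small `A` and `B` factors. On a single CPU the mesh has one device and the code runs unchanged; with several GPUs each device holds only its slice of `W` and `B`.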

Paper Structure

This paper contains 15 sections, 1 equation, 2 figures, and 2 tables.

Figures (2)

  • Figure 1: JORA is a library that aids Retrieval Augmented Fine-Tuning by eliminating unnecessary boilerplate and enabling memory-efficient training through tensor parallelism and LoRA.
  • Figure 2: JORA provides a simple GUI for fine-tuning.
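
Figure 1 attributes part of JORA's memory efficiency to LoRA. The arithmetic behind this is simple: for a frozen d_in × d_out weight, LoRA trains only two factors with r(d_in + d_out) entries instead of d_in · d_out. The snippet below is a sketch that assumes a square 4096 × 4096 projection (the hidden size of Llama-2-7B); the ranks are illustrative and are not the values used in the paper.

```python
# Trainable-parameter arithmetic for LoRA on a single weight matrix.
# Assumption: a 4096 x 4096 projection (Llama-2-7B hidden size); ranks are illustrative.

def full_params(d_in: int, d_out: int) -> int:
    """Parameters updated by full fine-tuning of a d_in x d_out matrix."""
    return d_in * d_out

def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in the LoRA factors A (d_in x r) and B (r x d_out)."""
    return rank * (d_in + d_out)

d = 4096
for r in (4, 8, 16):
    full = full_params(d, d)
    lora = lora_params(d, d, r)
    print(f"r={r:2d}: LoRA {lora:,} vs full {full:,} ({100 * lora / full:.2f}%)")
```

The same ratio applies to optimizer state such as Adam's moment estimates, which are kept only for trainable parameters.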