Adjoint sharding for very long context training of state space models

Xingzi Xu, Amir Tavanaei, Kavosh Asadi, Karim Bouyarmane

TL;DR

This paper tackles the memory bottleneck of training language models on very long contexts by introducing adjoint sharding, a constant-memory gradient computation method based on the adjoint formalism. It decomposes gradient calculations into independent vector-Jacobian products (VJPs) across time and layers, enabling parallelism and substantial memory savings while preserving equivalence to backpropagation. To address the quadratic growth of VJPs with context length, the authors introduce truncated adjoint sharding, which uses a finite history window $\bar{T}$ to achieve linear-like scaling, and they extend the approach to distributed and parallel training across multiple GPUs and instances. Empirical results demonstrate memory reductions up to around 3x for a 1.27B parameter state-space model at 1M context, allowing context lengths to exceed 100K tokens during training on modest multi-GPU infrastructure, highlighting practical impact for long-context fine-tuning and task-specific knowledge integration.

Abstract

Despite very fast progress, efficiently training large language models (LLMs) on very long contexts remains challenging. Existing methods fall back to training LLMs with short contexts (a maximum of a few thousand tokens in training) and use inference-time techniques when evaluating on long contexts (context windows above 1M tokens at inference). Unlike long-context inference, training on very long context input prompts is quickly limited by GPU memory availability and by the prohibitively long training times it requires on state-of-the-art hardware. Meanwhile, many real-life applications require not only inference but also training or fine-tuning with long context on specific tasks. Such applications include, for example, augmenting the context with various sources of raw reference information for fact extraction, fact summarization, or fact reconciliation tasks. We propose adjoint sharding, a novel technique that shards gradient calculation during training to reduce memory requirements by orders of magnitude, making training on very long contexts computationally tractable. Adjoint sharding is based on the adjoint method and computes gradients equivalent to those of backpropagation. We also propose truncated adjoint sharding to speed up the algorithm while maintaining performance. We provide a distributed version and a parallel version of adjoint sharding to further speed up training. Empirical results show that the proposed adjoint sharding algorithm reduces memory usage by up to 3X with a 1.27B parameter large language model on 1M context length training. This makes it possible to increase the maximum context length during training or fine-tuning of a 1.27B parameter model from 35K tokens to above 100K tokens on a training infrastructure composed of five AWS P4 instances.
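
To make concrete the claim that a sum of independent vector-Jacobian products can reproduce the backpropagation gradient, here is a minimal sketch on a toy scalar recurrence. It is an illustration under stated assumptions, not the paper's implementation: the time-invariant model $h_t = a h_{t-1} + b x_t$, $y_t = c h_t$, the squared-error loss, and the function names `forward`, `grad_a_backprop`, and `grad_a_sharded` are all hypothetical choices made for this example.

```python
def forward(a, b, c, xs):
    """Run h_t = a*h_{t-1} + b*x_t, y_t = c*h_t with h_0 = 0 (toy model, assumed)."""
    hs = [0.0]  # hs[t] is the state after t steps
    for x in xs:
        hs.append(a * hs[-1] + b * x)
    ys = [c * h for h in hs[1:]]  # ys[t-1] is the output at step t
    return hs, ys


def grad_a_backprop(a, b, c, xs, targets):
    """dL/da for L = sum_t 0.5*(y_t - target_t)^2 via one sequential backward pass."""
    hs, ys = forward(a, b, c, xs)
    lam = 0.0   # running dL/dh_t, propagated from the last step to the first
    grad = 0.0
    for t in range(len(xs), 0, -1):
        lam = c * (ys[t - 1] - targets[t - 1]) + a * lam
        grad += lam * hs[t - 1]
    return grad


def grad_a_sharded(a, b, c, xs, targets):
    """Same gradient as a double sum of independent terms over (t, i) with i <= t."""
    hs, ys = forward(a, b, c, xs)
    grad = 0.0
    for t in range(1, len(xs) + 1):
        v = c * (ys[t - 1] - targets[t - 1])  # dl^t/dh_t: the "vector" of the VJP
        for i in range(1, t + 1):
            # Jacobian dh_t/dh_i = a**(t-i); direct dependence dh_i/da = h_{i-1}
            grad += v * a ** (t - i) * hs[i - 1]
    return grad


if __name__ == "__main__":
    xs = [1.0, -0.5, 0.3, 0.8, -1.1]
    targets = [0.2, 0.0, -0.1, 0.4, 0.1]
    print(grad_a_backprop(0.9, 0.5, 1.2, xs, targets))
    print(grad_a_sharded(0.9, 0.5, 1.2, xs, targets))
```

Each inner-loop term depends only on the stored state `hs[i - 1]` and the output error at step `t`, so in this toy setting the terms could be evaluated in any order, or on different devices, before being summed.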

Paper Structure (29 sections, 27 equations, 6 figures, 6 tables, 4 algorithms)

Figures (6)

  • Figure 1: Compared to backpropagation (red lines), adjoint sharding (blue lines) significantly reduces memory requirements during training. The plot shows the memory cost of training $32\mathrm{M}$, $63\mathrm{M}$, $127\mathrm{M}$, $225\mathrm{M}$, and $1.27\mathrm{B}$ parameter State Space Models (SSMs) with batch size $2$ and the Adam optimizer on one GPU.
  • Figure 2: Adjoint sharding disassembles large models' gradient computations along the sequence dimension $t$ and the layer dimension $k$. When evaluating the gradient at time $t$, we perform $t$ vector-Jacobian products along the adjoint dimension $i$ for every layer index $k$.
  • Figure 3: Lines in red are fine-tuning-free methods and lines in blue are fine-tuning methods. Fine-tuning methods achieve better performance than fine-tuning-free methods but often suffer from out-of-memory issues [chen2023extendingcontextwindowlarge, ntkReddit2023, xiao2024efficientstreaminglanguagemodels, longchat2023, chen2024longloraefficientfinetuninglongcontext, peng2023yarnefficientcontextwindow, zhang2024soaring4k400kextending, tworkowski2023focusedtransformercontrastivetraining]. Lower values are better across all three tasks.
  • Figure 4: The adjoint states are computed sequentially backwards.
  • Figure 5: Computation schematic of $\mathrm{d} l^t/\mathrm{d} \boldsymbol{\theta}_{\boldsymbol{\mathcal{A}}_k}$, $\mathrm{d} l^t/\mathrm{d} \boldsymbol{\theta}_{\boldsymbol{\mathcal{B}}_k}$, and $\mathrm{d} l^t/\mathrm{d} \boldsymbol{\theta}_{\boldsymbol{\mathcal{C}}_k}$.
  • ...and 1 more figure
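
The Figure 2 caption states that the gradient at time $t$ needs $t$ vector-Jacobian products per layer, which is where the quadratic growth in context length comes from. The short sketch below tallies that count for a context of length $T$ and compares it with a truncated variant. The helper name `vjp_count` is hypothetical, and the truncated count assumes that step $t$ uses $\min(t, \bar{T})$ products per layer, which is a reading of the "finite history window" description rather than a formula quoted from the paper.

```python
def vjp_count(T, num_layers, window=None):
    """Total VJPs over a length-T context, counting per step and per layer.

    window=None follows the Figure 2 caption: step t uses t VJPs per layer.
    window=w is an assumed truncation: step t uses min(t, w) VJPs per layer.
    """
    total = 0
    for t in range(1, T + 1):
        per_layer = t if window is None else min(t, window)
        total += per_layer
    return total * num_layers


if __name__ == "__main__":
    # Full count grows like T*(T+1)/2 per layer; truncated count grows like T*window.
    print(vjp_count(1000, 1))             # 500500
    print(vjp_count(1000, 1, window=10))  # 9955
```

Under these assumptions the full count is $T(T+1)/2$ per layer, while the truncated count is at most $T \bar{T}$ per layer, which is linear in $T$ for a fixed window.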

Theorems & Definitions (2)

  • Proof 1
  • Proof 2