DAC-JAX: A JAX Implementation of the Descript Audio Codec

David Braun

TL;DR

This work presents an open-source JAX implementation of the Descript Audio Codec (DAC) built on Flax, Optax, and related tools. Given the same input, it produces token sequences and reconstructed audio equivalent to those of the PyTorch reference. It supports chunked, memory-bounded compression and decompression and benchmarks both implementations on a consumer-grade GPU and a cluster GPU: on the consumer-grade GPU, DAC-JAX is faster at all chunk sizes, while on the cluster GPU it is faster for small chunks but slower for large ones. The paper discusses integration with Faust-JAX, Penzai for intermediate-value inspection, and data-flow tooling (ArgBind, AUX, CLU) to support training and evaluation in a JAX-centric ecosystem. The practical impact is faster codec performance on consumer hardware and a pathway toward end-to-end JAX-based music-generation pipelines, with the caveat of slower large-chunk processing on cluster GPUs. The work also highlights cross-framework tradeoffs, tooling gaps, and avenues for future research in LM-based music modeling and audio synthesis with JAX.

Abstract

We present an open-source implementation of the Descript Audio Codec (DAC) using Google's JAX ecosystem of Flax, Optax, Orbax, AUX, and CLU. Our codebase enables the reuse of model weights from the original PyTorch DAC, and we confirm that the two implementations produce equivalent token sequences and decoded audio if given the same input. We provide a training and fine-tuning script which supports device parallelism, although we have only verified it using brief training runs with a small dataset. Even with limited GPU memory, the original DAC can compress or decompress a long audio file by processing it as a sequence of overlapping "chunks." We implement this feature in JAX and benchmark the performance on two types of GPUs. On a consumer-grade GPU, DAC-JAX outperforms the original DAC for compression and decompression at all chunk sizes. However, on a high-performance, cluster-based GPU, DAC-JAX outperforms the original DAC for small chunk sizes but performs worse for large chunks.
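
The chunked processing mentioned in the abstract can be illustrated with a minimal JAX sketch. This is not the DAC or DAC-JAX implementation: `toy_encode` is a hypothetical stand-in for a codec encoder (an unpadded, strided moving-average convolution), and `KERNEL`, `STRIDE`, `frames_per_chunk`, and `encode_chunked` are names invented for this example. The sketch only shows why consecutive chunks must overlap: each chunk is extended by `KERNEL - STRIDE` samples so that the concatenated per-chunk outputs match the output of processing the whole signal at once.

```python
import jax
import jax.numpy as jnp

KERNEL = 8  # toy receptive field in samples (assumption for illustration)
STRIDE = 4  # toy hop between output frames in samples (assumption)


@jax.jit
def toy_encode(x):
    """Hypothetical stand-in for an encoder: unpadded strided 1-D convolution."""
    kernel = jnp.ones((KERNEL,)) / KERNEL  # simple moving average
    out = jax.lax.conv_general_dilated(
        x[None, None, :],        # (batch, channels, time)
        kernel[None, None, :],   # (out_channels, in_channels, width)
        window_strides=(STRIDE,),
        padding="VALID",
    )
    return out[0, 0]


def encode_chunked(x, frames_per_chunk):
    """Encode x piece by piece, so only one chunk is processed at a time."""
    hop = frames_per_chunk * STRIDE  # samples advanced from chunk to chunk
    win = hop + KERNEL - STRIDE      # chunk length including the overlap
    pieces = []
    # Start a chunk only where at least one full kernel window still fits.
    for start in range(0, x.shape[0] - KERNEL + 1, hop):
        pieces.append(toy_encode(x[start:start + win]))
    return jnp.concatenate(pieces)


x = jax.random.normal(jax.random.PRNGKey(0), (1000,))
full = toy_encode(x)                              # whole signal at once
chunked = encode_chunked(x, frames_per_chunk=16)  # overlapping chunks
print(full.shape, chunked.shape, bool(jnp.allclose(full, chunked, atol=1e-5)))
```

In this toy setting, peak memory depends on the chunk length rather than on the full signal length. Because `toy_encode` is jitted, the shorter final chunk triggers one extra compilation for its shape.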

Paper Structure

This paper contains 17 sections, 3 figures, and 2 tables.

Figures (3)

  • Figure 1: Log-log plot of the data from Table \ref{tab:execution-2080}.
  • Figure 2: Log-log plot of the data from Table \ref{tab:execution-l40}.
  • Figure 3: Log-log plot of the unitless ratio of hop size to execution time.