Llamba: Scaling Distilled Recurrent Models for Efficient Language Processing

Aviv Bick, Tobias Katsch, Nimit Sohoni, Arjun Desai, Albert Gu

TL;DR

Llamba advances efficient language modeling by distilling Transformer knowledge into subquadratic recurrent architectures (Discrete Mamba-2) using MOHAWK, achieving high throughput with less than 0.1% of the training data typically used for models of similar size. The approach preserves Llama-based architectural cues while introducing architectural distillation and on-device optimizations, enabling practical edge deployment with 4-bit quantization. Empirical results show Llamba-1B/3B/8B delivering benchmark performance comparable to Transformer baselines at higher throughput, with MMLU benefiting notably from distillation on fineweb-edu. This work highlights a promising path for scalable, memory-efficient language models that maintain quality while enabling private, on-device processing.

Abstract

We introduce Llamba, a family of efficient recurrent language models distilled from Llama-3.x into the Mamba architecture. The series includes Llamba-1B, Llamba-3B, and Llamba-8B, which achieve higher inference throughput and handle significantly larger batch sizes than Transformer-based models while maintaining comparable benchmark performance. Furthermore, Llamba demonstrates the effectiveness of cross-architecture distillation using MOHAWK (Bick et al., 2024), achieving these results with less than 0.1% of the training data typically used for models of similar size. To take full advantage of their efficiency, we provide an optimized implementation of Llamba for resource-constrained devices such as smartphones and edge platforms, offering a practical and memory-efficient alternative to Transformers. Overall, Llamba improves the tradeoff between speed, memory efficiency, and performance, making high-quality language models more accessible.

Paper Structure

This paper contains 20 sections, 5 figures, 5 tables.

Figures (5)

  • Figure 1: Average accuracy over multiple benchmarks, including ARC Challenge, ARC Easy, PIQA, Winogrande, HellaSwag, OpenBookQA, and MMLU, which together assess a model's common-sense reasoning and language understanding.
  • Figure 2: Comparison of the Discrete Mamba-2 block and the Llamba architecture.
  • Figure 3: An evaluation of Llamba-8B's knowledge distillation step (MOHAWK's stage 3) across three datasets: C4, fineweb, and fineweb-edu. Each model first underwent hidden-state alignment (MOHAWK's stage 2) on its respective dataset using 4 billion tokens and was then evaluated after knowledge distillation on 1 billion tokens. All three datasets yield similar results on most benchmarks, but MMLU improves notably with fineweb-edu and not with fineweb or C4.
  • Figure 4: Tokens processed at different batch sizes across various models. All models were compiled using torch.compile(model, fullgraph=True) with CUDA graph compilation. We evaluated three settings: (1) Llamba-8B with gen_len=8192, (2) Llama-3.1-8B with gen_len=2048, and (3) Llama-3.1-8B with gen_len=8192. Each was tested with prompt_len=1 and batch sizes ranging from 8 to 2048. The results show that Llamba-8B achieves the highest throughput, particularly at larger batch sizes, where Transformers either slow down or run out of memory (OOM).
  • Figure 5: Comparison of on-device decoding throughput and memory consumption between Llamba-8B and Llama-3.1-8B at 4-bit quantization in MLX running on an Apple Silicon M3 Pro (36GB). Llamba maintains constant high throughput and low memory consumption, while the inference performance of Llama drops linearly with increasing context size.
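
The contrast described in Figures 4 and 5 (Transformer memory growing with context while a recurrent model stays flat) can be illustrated with a rough back-of-the-envelope calculation. This is a sketch under stated assumptions, not the paper's measurement code: the attention shape (32 layers, 8 key/value heads, head dimension 128) is taken from the public Llama-3.1-8B configuration, the 16-bit cache entries are an assumption that may not match the benchmarked setups, and `state_values_per_layer` is a purely hypothetical placeholder rather than Llamba's actual state size.

```python
# Illustrative memory arithmetic only; not the paper's benchmark code.
# Assumed Llama-3.1-8B attention shape (from its public configuration).
N_LAYERS = 32
N_KV_HEADS = 8
HEAD_DIM = 128
BYTES_PER_VALUE = 2  # assumption: 16-bit cache and state entries


def kv_cache_bytes(context_len: int) -> int:
    """Per-sequence key/value cache size; grows linearly with context length."""
    # Factor 2: one key and one value vector per token, per layer, per KV head.
    return 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES_PER_VALUE * context_len


def recurrent_state_bytes(context_len: int, state_values_per_layer: int = 1_000_000) -> int:
    """Per-sequence recurrent state size; does not depend on context length.

    state_values_per_layer is a hypothetical placeholder, not Llamba's real value.
    """
    del context_len  # the state has a fixed size, so context length is unused
    return N_LAYERS * state_values_per_layer * BYTES_PER_VALUE


# Print context length, KV cache size (GiB), and recurrent state size (GiB).
for n in (1024, 8192, 65536):
    print(n, kv_cache_bytes(n) / 2**30, recurrent_state_bytes(n) / 2**30)
```

Under these assumptions the key/value cache costs 128 KiB per token, so a single 8192-token sequence needs 1 GiB of cache, and that cost is multiplied by the batch size. The fixed-size recurrent state is why larger batches and longer generations remain feasible for a model like Llamba where a Transformer may run out of memory.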