Mixture-of-Channels: Exploiting Sparse FFNs for Efficient LLMs Pre-Training and Inference
Tong Wu, Yutong He, Bin Wang, Kun Yuan
TL;DR
The paper tackles the activation-memory bottleneck in large language models by profiling activation usage under FlashAttention and introducing Mixture-of-Channels (MoC), an FFN architecture that activates only the Top-K channels per token, selected by SwiGLU's gating. MoC reduces activation memory during pre-training and speeds up decoding by loading only the weights of the selected channels, aided by hardware-aware kernels and gradient checkpointing. Across LLaMA-family models and beyond, MoC achieves substantial memory savings with competitive perplexity and delivers end-to-end inference speedups of around 1.13×. The approach is orthogonal to attention optimizations such as FlashAttention and can be combined with mixed-precision training and other memory-saving techniques; integration with Mixture-of-Experts (MoE) is left to future work.
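To make the Top-K channel idea concrete, here is a minimal NumPy sketch of a SwiGLU FFN that keeps only K intermediate channels per token. It is an illustration, not the authors' implementation: the function names `silu` and `moc_ffn` are hypothetical, and it assumes channels are ranked by the value of the gate activation SiLU(x W_gate). For clarity it still computes the full up-projection and zeroes out unselected channels; an efficient kernel would avoid that work.

```python
import numpy as np

def silu(z):
    # SiLU (Swish) activation: z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

def moc_ffn(x, w_gate, w_up, w_down, k):
    """Top-K channel SwiGLU FFN (illustrative sketch).

    x: (T, d) token activations
    w_gate, w_up: (d, f) projections; w_down: (f, d) projection
    k: number of intermediate channels kept per token (1 <= k <= f)
    """
    gate = silu(x @ w_gate)  # (T, f) SwiGLU gate activations
    # Assumption: keep the K channels with the largest gate value per token.
    idx = np.argpartition(-gate, k - 1, axis=-1)[:, :k]  # (T, k), unordered
    mask = np.zeros(gate.shape, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    # Zero every channel outside the per-token Top-K set.
    h = np.where(mask, gate * (x @ w_up), 0.0)  # (T, f), K nonzeros per row
    return h @ w_down, idx
```

Setting `k = f` recovers the dense SwiGLU FFN, which is a convenient sanity check; smaller `k` leaves each token with exactly K nonzero intermediate channels.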
Abstract
Large language models (LLMs) have demonstrated remarkable success across diverse artificial intelligence tasks, driven by scaling laws that correlate model size and training data with performance improvements. However, this scaling paradigm incurs substantial memory overhead, creating significant challenges for both training and inference. While existing research has primarily addressed parameter and optimizer-state memory reduction, activation memory, particularly from feed-forward networks (FFNs), has become the critical bottleneck, especially when FlashAttention is used. In this work, we conduct a detailed memory profiling of LLMs and identify FFN activations as the predominant source of activation memory overhead. Motivated by this, we introduce Mixture-of-Channels (MoC), a novel FFN architecture that selectively activates only the Top-K most relevant channels per token, as determined by SwiGLU's native gating mechanism. MoC substantially reduces activation memory during pre-training and improves inference efficiency by reducing memory access through partial weight loading into GPU SRAM. Extensive experiments validate that MoC delivers significant memory savings and throughput gains while maintaining competitive model performance.
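The inference benefit comes from partial weight loading: once a token's Top-K channels are known, only the matching slices of the up- and down-projections have to be read. The NumPy sketch below illustrates this for a single decoding step. The names `moc_decode_step` and `weights_read_fraction` are hypothetical, channels are assumed to be ranked by the SiLU gate value, and the traffic estimate assumes W_gate is still read in full (it is needed to choose the channels) while W_up and W_down are read only for the K selected channels. It says nothing about how the paper's hardware-aware kernels stage data in SRAM.

```python
import numpy as np

def silu(z):
    # SiLU (Swish) activation: z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

def moc_decode_step(x, w_gate, w_up, w_down, k):
    """One decoding step for a single token x of shape (d,).

    Only K columns of w_up and K rows of w_down are gathered.
    """
    gate = silu(x @ w_gate)                  # (f,) full gate, used to pick channels
    idx = np.argpartition(-gate, k - 1)[:k]  # indices of the K largest gate values
    up_k = x @ w_up[:, idx]                  # reads a (d, k) slice instead of (d, f)
    h_k = gate[idx] * up_k                   # (k,) SwiGLU output on selected channels
    return h_k @ w_down[idx, :]              # reads a (k, d) slice instead of (f, d)

def weights_read_fraction(d, f, k):
    # Hypothetical accounting: W_gate read in full, W_up and W_down only
    # for the K selected channels, relative to reading all three dense matrices.
    return (d * f + 2 * d * k) / (3 * d * f)
```

Under this accounting, keeping K = f/8 channels reads (1 + 2/8)/3 = 5/12 of the dense FFN weights per token; the real saving depends on how the gate projection and kernels are implemented.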
