
FlashKAT: Understanding and Addressing Performance Bottlenecks in the Kolmogorov-Arnold Transformer

Matthew Raffel, Lizhong Chen

TL;DR

This work investigates why KAT trains slowly even though its FLOPs are close to those of an MLP-based Transformer, and shows that backward-pass memory stalls, driven by gradient accumulation with atomic adds, are the primary bottleneck. It analyzes the gradient structure of the group-wise rational (PAU) activations and demonstrates that conventional backward kernels incur heavy global memory traffic. The authors introduce FlashKAT, a memory-aware kernel that restructures gradient accumulation around block-level reductions and shared memory, eliminating excessive atomic adds and drastically reducing memory stalls. Empirical results show that FlashKAT achieves up to ~140x speedups in the backward pass and up to ~86x faster training on ImageNet-1K while reducing gradient rounding errors, bringing KAT closer to ViT-like training speed and stability.
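
To make the accumulation pattern concrete, the Python sketch below contrasts two ways of summing per-element gradient contributions into one shared coefficient: one global update per element (the pattern atomic adds produce) versus a local reduction per block followed by a single global update per block. It is a CPU emulation under stated assumptions, not the FlashKAT kernel: float32 arithmetic is emulated by rounding through `struct`, the block size of 256 is assumed, a sequential loop stands in for the in-block reduction (a real kernel would typically use a parallel tree reduction in registers or shared memory), and the input (one large term followed by 65,536 terms of 2^-10) is contrived so that the rounding effect is deterministic.

```python
import math
import struct


def f32(x):
    """Round a Python float (IEEE double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def atomic_style_sum(values):
    """Every element updates one global accumulator (one 'atomic add' each)."""
    acc, updates = 0.0, 0
    for v in values:
        acc = f32(acc + v)
        updates += 1
    return acc, updates


def block_reduce_sum(values, block=256):
    """Each block reduces locally, then issues a single global update.

    The block size of 256 is an illustrative assumption.
    """
    acc, updates = 0.0, 0
    for start in range(0, len(values), block):
        local = 0.0  # stands in for a per-block partial sum kept on-chip
        for v in values[start:start + block]:
            local = f32(local + v)
        acc = f32(acc + local)  # one global update per block
        updates += 1
    return acc, updates


# Contrived contributions: one large term followed by many small ones.
values = [16384.0] + [2.0 ** -10] * 65536
exact = math.fsum(values)

atomic_sum, atomic_updates = atomic_style_sum(values)
block_sum, block_updates = block_reduce_sum(values)

print(atomic_updates, block_updates)          # 65537 257
print(exact - atomic_sum, exact - block_sum)  # 64.0 0.25
```

With per-element updates, each 2^-10 term is exactly half a float32 ulp of the accumulator at 16384, so round-half-to-even discards it and the sum never moves. Blocking first combines the small terms into partials of 0.25 that survive the global add, and it cuts global updates from 65,537 to 257; the remaining 0.25 error comes from small terms that share a block with the large one or arrive as a lone tail. On a GPU the order of atomic adds is also nondeterministic, which this sequential emulation does not model.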

Abstract

The Kolmogorov-Arnold Network (KAN) has been gaining popularity as an alternative to the multi-layer perceptron (MLP) owing to its increased expressiveness and interpretability. Even so, the KAN is orders of magnitude slower because of its increased computational cost and also suffers from training instability, limiting its applicability to larger-scale tasks. Recently, the Kolmogorov-Arnold Transformer (KAT) has been proposed, which achieves FLOPs similar to those of the traditional Transformer with MLPs by leveraging the Group-Rational KAN (GR-KAN). Unfortunately, despite the comparable FLOPs, our testing reveals that KAT still trains 123x slower, indicating that there are performance bottlenecks beyond FLOPs. In this paper, we conduct a series of experiments to understand the root cause of the slowdown in KAT. We find that the slowdown can be isolated to memory stalls, linked more specifically to inefficient gradient accumulation in the backward pass of GR-KAN. To address this memory bottleneck, we propose FlashKAT, which minimizes accesses to slow memory and the use of atomic adds through a restructured kernel. Evaluations demonstrate that FlashKAT achieves a training speedup of 86.5x compared with the state-of-the-art KAT, while reducing rounding errors in the computation of the gradients.
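
The accumulation problem follows from how GR-KAN shares parameters across channels. The sketch below assumes the "safe" Padé (PAU) form used in KAT, F(x) = P(x) / (1 + |Q(x)|), with numerator degree M = 5, denominator degree N = 4, and one coefficient set per channel group; these choices, and the hypothetical function names `rational` and `coeff_grads`, are illustrative assumptions rather than FlashKAT's code. Writing D = 1 + |Q(x)| and g for the upstream gradient dL/dF, it accumulates dL/da_i = sum g x^i / D and dL/db_j = -sum g P sign(Q) x^j / D^2 over every element of the group.

```python
# Hypothetical sketch of a group-wise rational ("safe PAU") activation and the
# gradient of its shared coefficients. The functional form and the degrees
# M = 5, N = 4 are assumptions, not taken from FlashKAT's kernels.
M, N = 5, 4


def rational(x, a, b):
    """Return F(x) = P(x) / (1 + |Q(x)|) plus P(x), Q(x), and the denominator."""
    p = sum(a[i] * x ** i for i in range(M + 1))         # P(x) = a0 + a1 x + ... + aM x^M
    q = sum(b[j - 1] * x ** j for j in range(1, N + 1))  # Q(x) = b1 x + ... + bN x^N
    denom = 1.0 + abs(q)
    return p / denom, p, q, denom


def coeff_grads(xs, gys, a, b):
    """Accumulate dL/da and dL/db over every element sharing one group's (a, b).

    xs holds inputs from all batch/token positions and channels in the group,
    gys the matching upstream gradients dL/dF. Every element contributes to the
    same M + 1 + N coefficients, so a GPU kernel has to reduce these
    contributions, e.g. with atomic adds or block-level reductions.
    """
    da = [0.0] * (M + 1)
    db = [0.0] * N
    for x, gy in zip(xs, gys):
        _, p, q, denom = rational(x, a, b)
        sign_q = (q > 0) - (q < 0)  # derivative of |q|, taken as 0 at q == 0
        for i in range(M + 1):
            da[i] += gy * x ** i / denom  # dF/da_i = x^i / D
        for j in range(1, N + 1):
            db[j - 1] -= gy * p * sign_q * x ** j / denom ** 2  # dF/db_j = -P sign(Q) x^j / D^2
    return da, db
```

The outer loop runs over every element in the group while the inner loops touch only M + 1 + N = 10 coefficients, so on a GPU many threads target the same few addresses. That many-to-few reduction is the gradient accumulation the abstract identifies as the source of the memory stalls.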

Paper Structure

This paper contains 21 sections, 8 equations, 3 figures, 8 tables, and 2 algorithms.

Figures (3)

  • Figure 1: Comparison of training time (Fwd+Bwd) for ViT and KAT.
  • Figure 2: The warp‐state statistics for the KAT group-wise rational function backward pass. "Computing - Selected" is where the warp does useful computation; all the others are memory stall (MS) states.
  • Figure 3: The warp‐state statistics for the FlashKAT group-wise rational function backward pass. "Computing - Selected" is where the warp does useful computation; all the others are memory stall (MS) states.