Flash-KMeans: Fast and Memory-Efficient Exact K-Means

Shuo Yang, Haocheng Xi, Yilong Zhao, Muyang Li, Xiaoze Fan, Jintao Zhang, Han Cai, Yujun Lin, Xiuyu Li, Kurt Keutzer, Song Han, Chenfeng Xu, Ion Stoica

TL;DR

This work revisits $k$-means under the lens of modern AI system design to enable it as an online primitive. It proposes flash-kmeans, an IO-aware and contention-free implementation for modern GPU workloads.

Abstract

$k$-means has historically been positioned primarily as an offline processing primitive, typically used for dataset organization or embedding preprocessing rather than as a first-class component in online systems. In this work, we revisit this classical algorithm under the lens of modern AI system design and enable $k$-means as an online primitive. We point out that existing GPU implementations of $k$-means remain fundamentally bottlenecked by low-level system constraints rather than theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck due to the massive explicit materialization of the $N \times K$ distance matrix in High Bandwidth Memory (HBM). Simultaneously, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregations. To bridge this performance gap, we propose flash-kmeans, an IO-aware and contention-free $k$-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to completely bypass intermediate memory materialization; (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. Furthermore, we integrate algorithm-system co-designs, including chunked-stream overlap and cache-aware compile heuristics, to ensure practical deployability. Extensive evaluations on NVIDIA H200 GPUs demonstrate that flash-kmeans achieves up to 17.9$\times$ end-to-end speedup over the best baselines, while outperforming industry-standard libraries such as cuML and FAISS by 33$\times$ and over 200$\times$, respectively.
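The idea behind fusing distance computation with an online argmin can be illustrated with a small CPU-side sketch. This is not the authors' GPU kernel: the function names (`assign_materialized`, `assign_online`) and the `tile_k` parameter are hypothetical, and plain Python lists stand in for HBM and SRAM tiles. The sketch only shows that keeping a running minimum per point gives the same assignment as building the full $N \times K$ distance matrix first.

```python
from typing import List, Sequence

Point = Sequence[float]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def assign_materialized(points: List[Point], centroids: List[Point]) -> List[int]:
    """Baseline: build the full N x K distance matrix, then take a row-wise argmin."""
    dist = [[sq_dist(p, c) for c in centroids] for p in points]
    return [min(range(len(row)), key=row.__getitem__) for row in dist]


def assign_online(points: List[Point], centroids: List[Point], tile_k: int = 2) -> List[int]:
    """Illustrative online argmin: visit centroids tile by tile and keep only a
    running (best distance, best index) pair per point. No N x K matrix is stored."""
    n = len(points)
    best_dist = [float("inf")] * n
    best_idx = [-1] * n
    for start in range(0, len(centroids), tile_k):
        tile = centroids[start:start + tile_k]  # hypothetical stand-in for an on-chip tile
        for i, p in enumerate(points):
            for j, c in enumerate(tile):
                d = sq_dist(p, c)
                if d < best_dist[i]:  # strict '<' keeps the lowest index on ties
                    best_dist[i] = d
                    best_idx[i] = start + j
    return best_idx


points = [(0.0, 0.0), (1.0, 1.0), (9.0, 9.0), (5.0, 4.0)]
centroids = [(0.0, 0.0), (10.0, 10.0), (5.0, 5.0)]
print(assign_online(points, centroids))
```

The extra state is two values per point regardless of $K$, which is why the intermediate matrix never needs to be written out in this formulation.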

Paper Structure (15 sections, 3 equations, 5 figures, 3 algorithms)

Figures (5)

  • Figure 1: Overview of flash-kmeans and Performance Breakdown. (a) Inspired by IO-aware attention mechanisms, FlashAssign streams data blocks from HBM to SRAM, fusing distance computation with an online argmin operator to completely bypass the materialization of the massive $N \times K$ distance matrix. (b) Compared to standard k-means implementations, flash-kmeans drastically compresses both the assignment IO bottleneck and the update synchronization bottleneck.
  • Figure 2: Illustration of the centroid update stage and atomic contention. (a) Standard scatter-style update: Tokens are directly scattered to their assigned centroids. The highly irregular mapping leads to severe write-side atomic contention when multiple threads update the same centroid simultaneously. (b) Sort-inverse update: flash-kmeans first sorts the tokens by their cluster IDs and constructs an inverse mapping. This transforms the unstructured scatter into regularized, segment-level localized reductions. (c) Execution timeline comparison: The timeline reveals that the standard approach stalls frequently due to atomic lock contention on HBM, whereas Sort-Inverse Update issues contention-free memory writes, significantly hiding latency and accelerating the reduction phase.
  • Figure 3: End-to-End Latency of flash-kmeans compared to standard baselines. We group the evaluation into four representative regimes. The y-axis is presented in log scale to accommodate magnitude differences. Red × marks denote Out-Of-Memory failures (e.g., standard PyTorch fails to explicitly materialize the $N \times K$ distance matrix in Large $K$ workloads). flash-kmeans delivers consistent and substantial speedups across diverse shapes and handles extreme limits gracefully.
  • Figure 4: Kernel-level latency breakdown. Latency comparison of our custom kernels versus standard implementations across diverse extreme workloads ($D=128$). Left: FlashAssign completely removes HBM distance materialization, scaling up to a 21.2$\times$ speedup. Right: Sort-Inverse Update replaces per-token scatter atomic adds with sorted localized merges, eliminating atomic contention and achieving up to a 6.3$\times$ speedup.
  • Figure 5: Effectiveness of the cache-aware compile heuristic. Left: Compilation and configuration search time (log scale). The heuristic slashes tuning overhead by up to 175$\times$, entirely bypassing the severe compilation bottleneck of exhaustive auto-tuning. Right: Kernel iteration latency. The heuristic seamlessly matches the optimal performance found by exhaustive search across various shapes, ensuring near-zero runtime degradation.
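
The contrast between the scatter-style update and the sort-inverse update described in the Figure 2 caption can be sketched on the CPU as follows. This is a minimal illustration under stated assumptions, not the paper's kernel: the names `update_scatter` and `update_sort_inverse` are hypothetical, and a stable sort of token indices by cluster ID plays the role of the inverse mapping. On a GPU, the inner accumulation in the scatter version would be an atomic add, while the sorted version reduces one contiguous segment per cluster.

```python
from typing import List, Sequence, Tuple

Point = Sequence[float]


def update_scatter(points: List[Point], labels: List[int], k: int) -> Tuple[List[List[float]], List[int]]:
    """Scatter-style update: each token adds itself to its centroid's accumulator.
    On a GPU, each of these adds would be an atomic operation."""
    d = len(points[0])
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    for p, c in zip(points, labels):
        counts[c] += 1
        for t in range(d):
            sums[c][t] += p[t]
    return sums, counts


def update_sort_inverse(points: List[Point], labels: List[int], k: int) -> Tuple[List[List[float]], List[int]]:
    """Illustrative sort-then-segment update: sort token indices by cluster ID,
    then reduce each contiguous segment with a single writer per cluster."""
    d = len(points[0])
    # order[pos] is the original token index at sorted position pos
    order = sorted(range(len(points)), key=labels.__getitem__)
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    pos = 0
    while pos < len(order):
        c = labels[order[pos]]
        acc = [0.0] * d
        end = pos
        while end < len(order) and labels[order[end]] == c:
            p = points[order[end]]
            for t in range(d):
                acc[t] += p[t]
            end += 1
        sums[c] = acc            # one write per cluster segment
        counts[c] = end - pos
        pos = end
    return sums, counts


points = [(0.0, 0.0), (1.0, 1.0), (9.0, 9.0), (5.0, 4.0)]
labels = [0, 0, 1, 2]
print(update_sort_inverse(points, labels, k=4))
```

Dividing each row of `sums` by the matching entry of `counts` gives the new centroids; clusters with a zero count (cluster 3 in this example) need a separate policy, such as keeping the previous centroid, which this sketch leaves out.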