An Approach Towards Learning K-means-friendly Deep Latent Representation
Debapriya Roy
TL;DR
This work tackles clustering of high-dimensional data by learning a clustering-friendly latent representation together with stable K-means centers. It introduces a two-step scheme. First, a pretrained autoencoder is fine-tuned with a reconstruction loss and a differentiable centering loss that shapes the embeddings toward a clustering-friendly structure. Second, after each epoch, the centroids are reinitialized by running classical K-means on the latent space. The centering loss uses a differentiable centering function $G_{K,f}$ with parameter $\alpha$ to balance reconstruction and clustering, and the method improves ACC and NMI across multiple datasets. The centroid reinitialization step is shown to be crucial for performance, making the approach a practical alternative to both fully joint and fully pretraining-based deep clustering methods.
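The differentiable centering term can be illustrated with a softmax-weighted distance to the centers. This is a minimal sketch, not the paper's implementation: it assumes $f$ is the squared Euclidean distance and that $G_{K,f}$ takes the form $\sum_k f(h, r_k)\,\mathrm{softmax}_k(-\alpha f(h, r_k))$, as in the softmax relaxation of K-means. The function names are hypothetical, and only the centering term is shown (no reconstruction loss or loss weighting). In this sketch, $\alpha$ acts as the sharpness of the softmax.

```python
import numpy as np

def squared_distances(h, centers):
    """Pairwise squared Euclidean distances, shape (n_points, n_clusters)."""
    # (n, 1, d) - (1, K, d) broadcasts to (n, K, d); sum over the feature axis
    diff = h[:, None, :] - centers[None, :, :]
    return np.sum(diff ** 2, axis=-1)

def soft_centering_loss(h, centers, alpha):
    """Assumed G_{K,f}: softmax(-alpha * f)-weighted distances, batch mean."""
    d = squared_distances(h, centers)
    # Shift by the row minimum before exponentiating for numerical stability;
    # the softmax weights are unchanged by this shift.
    logits = -alpha * (d - d.min(axis=1, keepdims=True))
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    return float(np.mean(np.sum(w * d, axis=1)))

def hard_centering_loss(h, centers):
    """Classical K-means objective: distance to the nearest center (argmin)."""
    d = squared_distances(h, centers)
    return float(np.mean(d.min(axis=1)))

# Usage: the soft loss shrinks toward the hard loss as alpha grows.
rng = np.random.default_rng(0)
h = rng.normal(size=(100, 2))
centers = np.array([[-1.0, 0.0], [1.0, 0.0]])
for alpha in (0.1, 1.0, 10.0, 100.0):
    print(alpha, soft_centering_loss(h, centers, alpha))
print("hard", hard_centering_loss(h, centers))
```

Because the softmax weights sum to one, the soft loss is never below the hard K-means objective, and it decreases toward it as $\alpha$ increases. This keeps the term differentiable with respect to both the embeddings and the centers.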
Abstract
Clustering is a long-standing problem in data mining. Classical centroid-based clustering approaches struggle with high-dimensional inputs such as images. With the advent of deep neural networks, a common approach to this problem is to map the data to a latent space of comparatively lower dimension and then perform the clustering in that space. The architectures adopted for this are generally autoencoders (AEs), which reconstruct a given input at the output. To represent the input in a compact form, the encoder of an AE learns to extract useful features that are decoded at the reconstruction end. A well-known centroid-based clustering algorithm is K-means. In the context of deep feature learning, recent works have shown empirically that it is important to learn the representations and the cluster centroids together. Along these lines of joint learning, a continuous variant of K-means has recently been proposed, in which the softmax function replaces argmax so that the clustering and network parameters can be learned jointly with stochastic gradient descent (SGD). However, unlike in K-means, where the input space stays constant, here the centroids are learned in parallel with the latent space for every batch of data. Such batch updates disagree with the concept of classical K-means, where the clustering space remains constant because it is the input space itself. To this end, we propose to alternately learn a clustering-friendly data representation and K-means-based cluster centers. Experiments on several benchmark datasets show improvements of our approach over previous approaches.
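The K-means half of the alternating scheme can be sketched as follows. This is a minimal illustration assuming plain Lloyd iterations, random initialization from data points, and squared Euclidean distance. `kmeans`, `reinitialize_centers` and the `encode` callable are hypothetical names, and the paper's actual initialization and stopping criteria may differ. Unlike the softmax relaxation, assignments here are hard (argmin of distance) and the latent codes stay fixed for the whole K-means run.

```python
import numpy as np

def kmeans(h, k, n_iter=100, seed=0):
    """Classical (Lloyd) K-means on a fixed set of latent codes h, shape (n, d)."""
    h = np.asarray(h, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialise the centers with k distinct data points (an assumption).
    centers = h[rng.choice(len(h), size=k, replace=False)].copy()
    labels = np.full(len(h), -1)
    for _ in range(n_iter):
        d = np.sum((h[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
        new_labels = d.argmin(axis=1)  # hard assignment to the nearest center
        if np.array_equal(new_labels, labels):
            break  # assignments unchanged: converged
        labels = new_labels
        for j in range(k):
            members = h[labels == j]
            if len(members):  # keep the old center if a cluster becomes empty
                centers[j] = members.mean(axis=0)
    return centers, labels

def reinitialize_centers(encode, data, k, epoch):
    """Hypothetical per-epoch step: embed the full dataset, then rerun K-means."""
    h = encode(data)  # the latent space is frozen during this K-means run
    centers, _ = kmeans(h, k, seed=epoch)
    return centers

# Usage: two well-separated blobs, with an identity "encoder" as a stand-in.
rng = np.random.default_rng(1)
blobs = np.vstack([rng.normal(0.0, 0.1, size=(50, 2)),
                   rng.normal(10.0, 0.1, size=(50, 2))])
print(reinitialize_centers(lambda x: x, blobs, k=2, epoch=0))
```

In an epoch loop, one would fine-tune the autoencoder with the centers held fixed, then call `reinitialize_centers` on the whole dataset before the next epoch, so that each centroid update sees a constant clustering space, as in classical K-means.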
