MLOps Blog

K-Means Clustering Explained

8th August, 2023

Clustering was introduced in 1932 by H.E. Driver and A.L. Kroeber in their paper “Quantitative expression of cultural relationship”. Since then, the technique has advanced considerably and has been used to discover unknown structure in many application areas, e.g., healthcare.

Clustering is a type of unsupervised learning in which inferences are drawn from unlabelled datasets. Generally, it is used to capture meaningful structure, underlying processes, and groupings inherent in a dataset. In clustering, the task is to divide the population into several groups such that data points in the same group are more similar to each other than to data points in other groups. In short, clustering groups objects based on their similarities and dissimilarities.


With clustering, data scientists can discover intrinsic groupings among unlabelled data. There are no universal criteria for a good clustering; it depends on the user and how they want to use the results for their specific needs. Clustering can be used to find unusual data points (outliers) or to identify unknown properties that suggest a suitable grouping of the dataset.

Let’s take an example. Imagine you work as a manager in a Walmart store and would like to better understand your customers so you can scale up your business with new and improved marketing strategies. Segmenting your customers manually is difficult. You have some data that contains their age and purchase history; here, clustering can help group customers based on their spending. Once the customers are segmented, you can define a different marketing strategy for each group according to its target audience.

What does clustering mean? | Source: Author

There are many clustering algorithms, grouped into different cluster models. Before choosing an algorithm for a use case, it is important to become familiar with the cluster models and check whether they suit the use case. Another thing to consider when choosing a clustering algorithm is the size of your dataset.

Datasets can contain millions of records, and not all algorithms scale efficiently. K-means is one of the most popular algorithms, and it scales well: each iteration takes time linear in the number of data points, O(n·k·d) for n points, k clusters, and d features. In this article, we will look at K-means in depth and at what makes it popular.

K-Means clustering

K-means is a centroid-based clustering algorithm: we calculate the distance between each data point and each centroid and assign the point to the cluster of the nearest one. The goal is to identify K groups in the dataset.

“K-means clustering is a method of vector quantization, originally from signal processing, that aims to partition n observations into k clusters in which each observation belongs to the cluster with the nearest mean, serving as a prototype of the cluster.” Source

It is an iterative process: each data point is assigned to a group, and over the iterations data points with similar features end up clustered together. The objective is to minimize the sum of squared distances between the data points and their cluster centroids, which identifies the group each data point should belong to.

Here, we divide a data space into K clusters and assign a mean value to each. Each data point is placed in the cluster whose mean is closest to it. Several distance metrics can be used to calculate this distance, although standard K-means uses the Euclidean distance.
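Written as a formula, the objective described above is the within-cluster sum of squared Euclidean distances. In this notation (introduced here for illustration), S_k is the set of data points x_n assigned to cluster k and C_k is that cluster's centroid:

$$J = \sum_{k=1}^{K} \sum_{x_n \in S_k} \lVert x_n - C_k \rVert^2$$

Neither the assignment step nor the centroid update can increase J, which is why the algorithm converges, although possibly to a local rather than a global minimum.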

How does K-means work?

Let’s take an example to understand how K-means works step by step. The algorithm can be broken down into five steps.

  1. Choosing the number of clusters 

The first step is to choose K, the number of clusters into which we will group the data. Let’s select K = 3.

  2. Initializing centroids

A centroid is the center of a cluster. Initially, the exact centers of the data points are unknown, so we select random data points and use them as the initial centroids, one per cluster. Here we initialize 3 centroids in the dataset.

K-means clustering – centroid | Source: Author
  3. Assign data points to the nearest cluster

Now that the centroids are initialized, the next step is to assign each data point Xn to its closest cluster centroid Ck.

K-means clustering – assign data points | Source: Author

In this step, we first calculate the distance between a data point X and a centroid C using the Euclidean distance metric:

$$d(X, C) = \sqrt{\sum_{i} (X_i - C_i)^2}$$

where i runs over the features. Each data point is then assigned to the cluster whose centroid is at the minimum distance.

K-means clustering | Source: Author
  4. Re-initialize centroids

Next, we recompute each centroid as the average of all data points assigned to its cluster.
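Written out, this update moves each centroid to the mean of the points currently assigned to it, with S_k (notation introduced here) the set of data points x_n in cluster k:

$$C_k = \frac{1}{\lvert S_k \rvert} \sum_{x_n \in S_k} x_n$$

The mean is the point that minimizes the sum of squared Euclidean distances to the members of the cluster, which is why this step never increases the K-means objective.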

K-means clustering | Source: Author
  5. Repeat steps 3 and 4

We keep repeating steps 3 and 4 until the centroids stop moving and the assignment of data points to clusters no longer changes. The result is a local optimum, which can depend on the initial centroids.
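The initialize, assign, update, and repeat steps can be condensed into a short NumPy sketch. It assumes the data is a NumPy array with one row per point, uses random data points as the initial centroids, and keeps a centroid in place if its cluster ever becomes empty (one of several possible choices). The function name `kmeans` is hypothetical, not a library API.

```python
import numpy as np


def kmeans(X, k, max_iter=100, seed=0):
    """Plain K-means (Lloyd's algorithm) with random data points as initial centroids."""
    rng = np.random.default_rng(seed)
    # Initialize: pick k distinct data points as the starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    labels = None
    for _ in range(max_iter):
        # Assign: Euclidean distance from every point to every centroid, shape (n, k)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        new_labels = dists.argmin(axis=1)
        # Stop once no point changes cluster
        if labels is not None and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # Update: move each centroid to the mean of its assigned points
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # keep the old centroid if the cluster is empty
                centroids[j] = members.mean(axis=0)
    return labels, centroids


# Usage on a tiny toy dataset with two obvious groups
X_demo = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
labels_demo, centroids_demo = kmeans(X_demo, k=2)
print(labels_demo)
print(centroids_demo)
```

Because the result depends on the randomly chosen starting centroids, running the function with a few different seeds and keeping the solution with the lowest within-cluster sum of squares is a common safeguard.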

K-means clustering | Source: Author

Does this iterative process sound familiar? K-means follows the same approach as Expectation-Maximization (EM). EM is an iterative method for finding maximum-likelihood estimates of parameters when the model depends on unobserved (latent) variables. It alternates between two steps, Expectation (E) and Maximization (M).

In K-means, the Expectation (E) step assigns each data point to its most likely (nearest) cluster, and the Maximization (M) step recomputes the centroids as the means of their assigned points, a least-squares update that minimizes the sum of squared distances within each cluster.

Centroid initialization methods 

Positioning the initial centroids can be challenging; the aim is to initialize them as close as possible to the optimal positions of the actual centroids. It is worth using a deliberate strategy for choosing the initial centroids, because the choice directly affects the overall runtime as well as the quality of the final clusters. The traditional way is to select the centroids randomly, but there are other methods as well, which we cover in this section.

  • Random Data Points

This is the traditional approach: K random data points are selected and used as the initial centroids, as in the example above. In this method, each data instance in the dataset has to be enumerated, and a record of the minimum and maximum value of each attribute has to be kept. This can be time-consuming, and as the dataset becomes more complex, the number of iterations needed to reach the correct centroids or clusters also increases.

  • Naive Sharding

The sharding centroid initialization algorithm primarily depends on the composite summation value of all the attributes for a particular instance or row in a dataset. The idea is to calculate the composite value and then use it to sort the instances of the data. Once the data set is sorted, it is then divided horizontally into k shards.

Sorting by composite value and sharding | Source

Finally, the attributes of each shard are averaged column by column. The collection of these per-shard mean vectors is the set of centroids used for initialization.

Centroid attribute values | Source

Computing the composite values and the shard means takes linear time, and the sort adds O(n log n). In practice, the resulting execution time is often much better than with random centroid initialization.
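A minimal NumPy sketch of naive sharding as described above. It assumes numeric attributes on comparable scales (otherwise the composite sum is dominated by the largest attribute); the function name `naive_sharding_centroids` is hypothetical.

```python
import numpy as np


def naive_sharding_centroids(X, k):
    """Initial centroids via naive sharding: sort rows by attribute sum, split into k shards, average each."""
    composite = X.sum(axis=1)             # composite value of each row
    order = np.argsort(composite)         # row order by composite value
    shards = np.array_split(X[order], k)  # k horizontal shards of (almost) equal size
    # The mean of each shard, attribute by attribute, becomes one initial centroid
    return np.array([shard.mean(axis=0) for shard in shards])


X_demo = np.array([[10, 10], [1, 1], [11, 11], [2, 2]], dtype=float)
init_centroids = naive_sharding_centroids(X_demo, k=2)
print(init_centroids)  # centroids at (1.5, 1.5) and (10.5, 10.5)
```

The resulting array can be passed to scikit-learn's `KMeans` as `init=init_centroids` together with `n_init=1`.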

  • K-Means++

K-means++ is a smart centroid initialization method for the K-means algorithm. The goal is to spread out the initial centroids: the first centroid is chosen randomly, and each remaining centroid is chosen with probability proportional to its squared distance from the nearest centroid already selected. This pushes the centroids far from one another.

Here are the simple steps to initialize centroids using K-means++:

  1. Randomly pick the first centroid (C1) from the data points.
  2. For each data point xi, calculate the distance D(xi) to the nearest centroid chosen so far:

$$D(x_i) = \min_{j} \lVert x_i - C_j \rVert$$

  3. Choose the next centroid from the data points, picking xi with probability proportional to D(xi)², so that points far from the existing centroids are the most likely choices.
  4. Repeat steps 2 and 3 until all K centroids have been chosen.

“With the k-means++ initialization, the algorithm is guaranteed to find a solution that is O(log k) competitive to the optimal k-means solution.” Source
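Below is a minimal sketch of standard k-means++ seeding (D² sampling) in NumPy; the function name and the `seed` parameter are illustrative, not a library API. In practice, scikit-learn's `KMeans` already uses k-means++ initialization by default (`init="k-means++"`), implemented as a greedy variant that evaluates several candidates per step.

```python
import numpy as np


def kmeans_plus_plus_init(X, k, seed=None):
    """k-means++ seeding: sample each new centroid with probability proportional to D(x)^2."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # Step 1: first centroid chosen uniformly at random from the data points
    centroids = [X[rng.integers(n)]]
    for _ in range(1, k):
        # Step 2: squared distance from each point to its nearest chosen centroid
        chosen = np.array(centroids)
        d2 = ((X[:, None, :] - chosen[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        # Step 3: already-chosen points have d2 == 0, so they cannot be picked twice
        # (assumes X has at least k distinct points, otherwise d2.sum() can be 0)
        probs = d2 / d2.sum()
        centroids.append(X[rng.choice(n, p=probs)])
    return np.array(centroids)


X_demo = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [20, 0], [21, 0]], dtype=float)
init_centroids = kmeans_plus_plus_init(X_demo, k=3, seed=0)
print(init_centroids)
```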

Implementing K-Means clustering in Python

Now that you are familiar with Clustering and K-means algorithms, it’s time to implement K-means using Python and see how it works on real data. 

We will be working on the Mall Visitors dataset to create a customer segmentation that can guide a marketing strategy. The Mall Visitors sample dataset can be found on Kaggle, and it summarizes the spending of around 2000 mall visitors.

Let’s clean, explore and prepare the data for the next phases where we will be segmenting customers.

Load the data and check for any missing values:

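import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#load the dataset
customer_data = pd.read_csv("/content/Mall_Customers.csv")
# The code below refers to the features as 'Annual_Income_(k$)' and 'Spending_Score';
# rename the columns if your copy of the CSV uses different headers

#read the data
customer_data.head()

#check for null or missing values
customer_data.isnull().sum()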
Mall visitors dataset | Source
Mall visitors dataset | Source: Author

We will be using the Annual Income and Spending Score to find the clusters in the data. The spending score is from 1 to 100 and is assigned based on customer behavior and spending nature. 

Implementing K-Means from scratch

There are open-source libraries that provide functions for many types of clustering algorithms, but before relying on them through a single function call, it is important to understand how those functions work. In this section, we will build a K-means clustering algorithm from scratch using random centroid initialization.

Let’s look at the data and see how it is distributed:

Implementing K-Means from scratch | Source: Author

From the above scatterplot, it is difficult to tell if there is any pattern in the dataset. This is where clustering will help.

First, we will initialize the centroids randomly:

K = 3  # number of clusters
centroids = customer_data.sample(n=K)  # K random customers as the initial centroids
Implementing K-Means from scratch | Source: Author

Next, we will iterate through each centroid and data point, calculate the distance between them, and assign each data point to the cluster with the nearest centroid. We then recompute the centroids and repeat until the centroids no longer change between iterations:

# Cluster on the two features used throughout this section
features = ["Annual_Income_(k$)", "Spending_Score"]
X = customer_data[features].copy()
centroids = centroids[features].reset_index(drop=True)

diff = 1
while diff != 0:
    # Euclidean distance from every data point to every centroid
    for k, row_c in centroids.iterrows():
        X[f"dist_{k}"] = np.sqrt(((X[features] - row_c) ** 2).sum(axis=1))
    dist_cols = [f"dist_{k}" for k in range(K)]
    # Assign each data point to the cluster with the nearest centroid
    X["Cluster"] = X[dist_cols].to_numpy().argmin(axis=1)
    # New centroids are the mean of the points in each cluster
    centroids_new = X.groupby("Cluster")[features].mean()
    # Total absolute centroid movement; 0 means the assignments have settled
    diff = (centroids_new - centroids).abs().to_numpy().sum()
    centroids = centroids_new

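Now, if we plot the dataset, all the data points should be clustered accordingly:

for k in range(K):
    plt.scatter(X.loc[X["Cluster"] == k, "Annual_Income_(k$)"], X.loc[X["Cluster"] == k, "Spending_Score"], label=f"Cluster {k}")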
Implementing K-Means from scratch | Source: Author

Implementing K-means using Scikit-Learn

That was quite a lot of code, and the algorithm might also need some optimization to improve its run time. Instead of optimizing it and rewriting it, we can use open-source library functions. Scikit-learn has a clustering package that includes an optimized KMeans implementation, which is very popular among researchers.


First, we will import the K-Means function then call the function by passing the number of clusters as an argument:

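from sklearn.cluster import KMeans
# Fit K-means with 3 clusters on the two features used above
km_sample = KMeans(n_clusters=3, n_init=10)
km_sample.fit(customer_data[['Annual_Income_(k$)','Spending_Score']])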

That’s all, your clustered data is ready. Let’s look at the data again:

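import seaborn as sns
labels_sample = km_sample.labels_
customer_data['label'] = labels_sample
# Color each customer by the cluster label K-means assigned
sns.scatterplot(data=customer_data, x='Annual_Income_(k$)', y='Spending_Score', hue='label', palette='Set1')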
Implementation using Scikit-Learn | Source: Author

We were able to create a customer segmentation with only a few lines of code using Scikit-Learn. For this particular dataset, the final clustered data is the same in both implementations (cluster numbers are arbitrary, so your labels may be numbered differently). So what did we learn about the mall customers from this segmentation?

Label 0: Savers, average to high income but spend wisely

Label 1: Carefree, low income but spend a lot

Label 2: Spenders, average to high income and spend a lot

This was difficult to see when we first plotted the data, but now we know there are three categories, and mall management can apply marketing strategies accordingly. For example, they might provide more saving offers to the Label 0 (savers) group and open more lucrative shops for the Label 2 (big spenders) group.

How to choose K?

Several factors can challenge the efficacy of the final output of the K-means clustering algorithm, and one of them is choosing the number of clusters (K). Selecting too few clusters will result in underfitting, while specifying too many can result in overfitting. Unfortunately, there is no definitive way to find the optimal number.

The optimal number of clusters depends on the similarity measures and the parameters used for clustering. So, to find the number of clusters in the data, we need to run K-means clustering for a range of values and compare the outcomes. There is no method that determines the exact value of K, but we can estimate it using techniques such as cross-validation, the elbow method, information criteria, the silhouette method, and the G-means algorithm.

Elbow method

Distance is one of the most commonly used metrics for comparing results across different K values. As the number of clusters K increases, the distance from the data points to their centroids decreases, reaching zero when K equals the number of data points. Because this measure keeps falling as K grows, we cannot simply pick the K that minimizes it. Instead, in the elbow method, we plot the mean distance to the centroids against K and look for the elbow point where the rate of decrease shifts. This elbow point can be used to determine K.

“The elbow method is a heuristic used in determining the number of clusters in a data set. The method consists of plotting the explained variation as a function of the number of clusters, and picking the elbow of the curve as the number of clusters to use.” - Source

Elbow point is used as a cutoff point in mathematical optimization to decide at which point the diminishing returns are no longer worth the additional cost. Similarly, in clustering, this is used to choose a number of clusters when adding another cluster doesn’t improve the outcomes of modeling. It is an iterative process where K-means clustering will be done on the dataset for a range of values of K as below.

  1. Perform K-means clustering for each K value in a range. For each K value, compute the average distance from the data points to their nearest centroid:
from sklearn.cluster import KMeans
from sklearn import metrics
from scipy.spatial.distance import cdist
import numpy as np
import matplotlib.pyplot as plt

x1 = np.array([3, 1, 1, 2, 1, 6, 6, 6, 5, 6, 7, 8, 9, 8, 9, 9, 8])
x2 = np.array([5, 4, 5, 6, 5, 8, 6, 7, 6, 7, 1, 2, 1, 2, 3, 2, 3])

X = np.array(list(zip(x1, x2))).reshape(len(x1), 2)

# k means determine k
distortions = []
K = range(1,10)
for k in K:
   kmeanModel = KMeans(n_clusters=k).fit(X)
   distortions.append(sum(np.min(cdist(X, kmeanModel.cluster_centers_, 'euclidean'), axis=1)) / X.shape[0])
  2. Plot each of these points and find the point where the mean distance stops falling steeply (the elbow):
# Plot the elbow
plt.plot(K, distortions, 'bx-')
plt.xlabel('k')
plt.ylabel('Distortion (mean distance to nearest centroid)')
plt.title('The Elbow Method showing the optimal k')
plt.show()
Elbow method | Source: Author

This is probably the most popular method for determining the optimal number of clusters. However, finding the elbow point can be a challenge, because in practice there may not be a sharp elbow.
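The distortion above is the mean Euclidean distance to the nearest centroid. A fitted scikit-learn `KMeans` model also exposes `inertia_`, the sum of squared distances to the closest centroid (the quantity K-means actually minimizes), which is a common alternative for the elbow plot. This sketch assumes synthetic data from `make_blobs` and checks that `inertia_` matches a manual computation.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data for illustration: 300 points around 4 centers
X, _ = make_blobs(n_samples=300, centers=4, random_state=0)

inertias = []
for k in range(1, 10):
    km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X)
    inertias.append(km.inertia_)  # sum of squared distances to the closest centroid

# inertia_ equals the sum of squared distances computed by hand (last model, k = 9)
manual = ((X - km.cluster_centers_[km.labels_]) ** 2).sum()
print(np.round(inertias, 1))
print(np.isclose(manual, km.inertia_))
```

Plotting `inertias` against K gives the same kind of elbow curve as the mean distance.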

Silhouette method

Finding an Elbow point is challenging in practice but there are other techniques to determine the optimal value of K and one of them is the Silhouette Score method. 

“Silhouette refers to a method of interpretation and validation of consistency within clusters of data. The technique provides a succinct graphical representation of how well each object has been classified.” - Source

The silhouette coefficient measures the quality of the clusters by checking how similar a data point is to its own cluster compared with other clusters. Silhouette analysis can be used to study the separation distance between the resulting clusters. The measure ranges between -1 and 1:

+1 indicates that the data point is far away from the neighboring cluster and thus optimally positioned.

0 indicates either it is on or very close to the decision boundary between two neighbor clusters.

-1 indicates that the data point has likely been assigned to the wrong cluster.
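For a single point i, the silhouette is s(i) = (b(i) - a(i)) / max(a(i), b(i)), where a(i) is the mean distance to the other points in its own cluster and b(i) is the smallest mean distance to the points of any other cluster. The sketch below computes it by hand on a small toy dataset (made up for illustration) and compares the result with `sklearn.metrics.silhouette_samples`; the helper name `silhouette_per_point` is hypothetical.

```python
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.metrics import silhouette_samples


def silhouette_per_point(X, labels):
    """s(i) = (b(i) - a(i)) / max(a(i), b(i)) for every point i."""
    D = cdist(X, X)  # pairwise Euclidean distances
    scores = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton cluster: silhouette is defined as 0
        # a(i): mean distance to the other members of its own cluster (D[i, i] is 0)
        a = D[i, same].sum() / (n_same - 1)
        # b(i): smallest mean distance to the members of any other cluster
        b = min(D[i, labels == c].mean() for c in np.unique(labels) if c != labels[i])
        scores[i] = (b - a) / max(a, b)
    return scores


X_demo = np.array([[1, 1], [1.5, 2], [3, 4], [5, 7], [3.5, 5], [4.5, 5], [3.5, 4.5]])
labels_demo = np.array([0, 0, 1, 1, 1, 1, 1])
print(np.round(silhouette_per_point(X_demo, labels_demo), 3))
print(np.round(silhouette_samples(X_demo, labels_demo), 3))
```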

To find an optimal value for the number of clusters K, we use a silhouette plot to display a measure of how close each point in one cluster is to a point in the neighboring clusters and thus provide a way to assess parameters like the number of clusters visually. Let’s see how it works.

  1. Run K-means clustering for a range of K values.
  2. For each value of K, find the average silhouette score of data points:
from sklearn.metrics import silhouette_score

sil_avg = []
range_n_clusters = [2, 3, 4, 5, 6, 7, 8]

for k in range_n_clusters:
 kmeans = KMeans(n_clusters = k).fit(X)
 labels = kmeans.labels_
 sil_avg.append(silhouette_score(X, labels, metric = 'euclidean'))
  3. Plot the collection of silhouette scores for each value of K
  4. Select the number of clusters for which the silhouette score is highest:
plt.plot(range_n_clusters, sil_avg, 'bx-')
plt.xlabel('Values of K')
plt.ylabel('Silhouette score')
plt.title('Silhouette analysis For Optimal k')
plt.show()
Silhouette analysis | Source: Author

Based on the silhouette analysis above, we choose 3 as the optimal value of K, because the average silhouette score is highest there, indicating that the data points are well positioned within their clusters.

Clustering evaluation metrics

In clustering, we don’t have any labeled data but just a set of features and the objective is to obtain high intra-cluster similarity and low inter-cluster similarity for those features. Evaluating the performance of any clustering algorithm is not as easy as calculating the number of errors or finding precision or recall like in supervised learning. Here we evaluate the outcomes based on the similarities or dissimilarities between data points. 

In the previous section, we saw how distance measures and the silhouette score can help find the optimal value of K. Many of these evaluation metrics can also be used to find the best parameters for other parametric clustering algorithms. A clustering algorithm is only as good as its similarity measure, so we need to make sure we use an appropriate one. One way is to experiment with different measures and determine which provides the most meaningful similarities for your data.

“In theory, the clustering researcher has acquired an intuition for the clustering evaluation, but in practise the mass of data on the one hand and the subtle details of data representation and clustering algorithms on the other hand, make an intuitive judgement impossible.” - PhD thesis, University of Stuttgart

There are several clustering evaluation metrics available and continuously evolving to help researchers with clustering. In this section, we will be discussing some of the most common and popular metrics.

Dunn index

Dunn Index is used to identify dense and well-separated groups. It is the ratio between minimum inter-cluster distance and maximum intra-cluster distance. The Dunn index can be computed as below:

$$DI = \frac{\min_{i \neq j} d(i, j)}{\max_{k} d(k)}$$

Here d(i, j) is the distance between clusters i and j, and the numerator takes its minimum over all pairs of clusters; d(k) is the intra-cluster distance of cluster k, and the denominator takes its maximum over all clusters. Algorithms that produce clusters with a high Dunn index are preferable, because their clusters are more compact and better separated from each other.
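scikit-learn does not provide a Dunn index, so here is a minimal sketch assuming one common choice of distances: the inter-cluster distance d(i, j) is the smallest distance between points of two different clusters, and the intra-cluster distance d(k) is the cluster's diameter (the largest distance between two of its points). Other definitions exist, and the function name `dunn_index` is hypothetical.

```python
import numpy as np
from scipy.spatial.distance import cdist, pdist


def dunn_index(X, labels):
    """Dunn index: smallest inter-cluster distance / largest cluster diameter."""
    clusters = [X[labels == c] for c in np.unique(labels)]
    # Largest intra-cluster distance (diameter) over all clusters; singletons have diameter 0
    max_diameter = max(pdist(c).max() if len(c) > 1 else 0.0 for c in clusters)
    # Smallest distance between two points that belong to different clusters
    min_separation = min(
        cdist(clusters[i], clusters[j]).min()
        for i in range(len(clusters))
        for j in range(i + 1, len(clusters))
    )
    return min_separation / max_diameter


X_demo = np.array([[0, 0], [0, 1], [5, 0], [5, 1]], dtype=float)
labels_demo = np.array([0, 0, 1, 1])
print(dunn_index(X_demo, labels_demo))  # separation 5 / diameter 1 = 5.0
```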

Silhouette score

The average silhouette score is also used as an evaluation measure in clustering. The best silhouette score is 1 and the worst is -1. Values close to zero indicate that data points lie on a boundary, i.e., the clusters overlap.


F-measure

In clustering, the common approach is to apply the F-measure to the precision and recall of pairs of data points, which is referred to as the pair-counting F-measure. Weighting recall lets you control how heavily false negatives are penalized. We can calculate the F-measure using the formula below:

$$F_\beta = \frac{(\beta^2 + 1) \cdot P \cdot R}{\beta^2 \cdot P + R}$$

Here P is precision, R is recall, and β is chosen such that recall is considered β times as important as precision. When β = 1, the F-measure is the harmonic mean of precision and recall. Computing precision and recall requires ground-truth classes for the objects: precision is the fraction of pairs clustered together that also share a class, and recall is the fraction of pairs sharing a class that are also clustered together.
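A sketch of the pair-counting F-measure built on `pair_confusion_matrix` from `sklearn.metrics.cluster` (available since scikit-learn 0.24). That function counts ordered pairs, so every count is doubled, but the doubling cancels out in precision and recall. The helper name `pair_f_measure` and the toy labels are illustrative.

```python
from sklearn.metrics.cluster import pair_confusion_matrix


def pair_f_measure(labels_true, labels_pred, beta=1.0):
    """Pair-counting F-measure: F_beta of pairwise precision and recall."""
    # Rows: ground truth (1 = pair shares a class); columns: prediction (1 = pair shares a cluster)
    C = pair_confusion_matrix(labels_true, labels_pred)
    tp, fp, fn = C[1, 1], C[0, 1], C[1, 0]
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return (beta**2 + 1) * precision * recall / (beta**2 * precision + recall)


labels_true = [0, 0, 0, 1, 1, 1]
labels_pred = [0, 0, 1, 1, 2, 2]
print(pair_f_measure(labels_true, labels_pred))          # F1 = 4/9 ≈ 0.444
print(pair_f_measure(labels_true, labels_pred, beta=2))  # recall weighted more: 10/27 ≈ 0.370
```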

Rand index

The Rand index measures how similar a clustering is to a benchmark (ground-truth) partition. It can be computed with the following formula:

$$RI = \frac{TP + TN}{TP + FP + FN + TN}$$

Here TP is the number of true positives, TN the number of true negatives, FP the number of false positives, and FN the number of false negatives, all counted over pairs of data points. In other words, the Rand index is the fraction of pairwise assignments that are correct.

TP is the number of pairs of data points that are clustered together both in the predicted partition and in the ground-truth partition; FP is the number of pairs clustered together in the predicted partition but not in the ground truth. Likewise, FN counts pairs that are together in the ground truth but separated in the prediction, and TN counts pairs that are separated in both.
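To make the pair counting concrete, the sketch below computes the Rand index directly from all pairs of points and checks it against `sklearn.metrics.rand_score` (scikit-learn 0.24+). The helper name `rand_index` and the toy labels are illustrative.

```python
from itertools import combinations

from sklearn.metrics import rand_score


def rand_index(labels_true, labels_pred):
    """Fraction of point pairs on which two partitions agree: (TP + TN) / all pairs."""
    agree = total = 0
    for i, j in combinations(range(len(labels_true)), 2):
        same_true = labels_true[i] == labels_true[j]
        same_pred = labels_pred[i] == labels_pred[j]
        agree += same_true == same_pred  # counts TP (both same) and TN (both different)
        total += 1
    return agree / total


labels_true = [0, 0, 0, 1, 1, 1]
labels_pred = [0, 0, 1, 1, 2, 2]
print(rand_index(labels_true, labels_pred))  # 10 of 15 pairs agree
print(rand_score(labels_true, labels_pred))
```

For real evaluations, the adjusted Rand index (`sklearn.metrics.adjusted_rand_score`), which corrects for chance agreement, is often preferred.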

To learn about more evaluation metrics, you can check out the scikit-learn clustering performance evaluation page.

GMM (Gaussian mixture models) vs K-means clustering algorithm

Mixture models are probabilistic models that represent the overall population as a mixture of subpopulations, each of which can be treated as a cluster. K-means is a fast and simple clustering method, but sometimes it cannot capture the inherent heterogeneity of the data.

Gaussian mixture models (GMM) can identify complex patterns and group them together, giving a closer representation of the real patterns within the dataset.

Introduction to GMM

In practice, unlabelled data can be widely spread, and it can be difficult to separate these data points into distinct clusters. A Gaussian Mixture Model (GMM) initializes a certain number of Gaussian distributions, each of which represents a cluster. In a GMM, we group similar data points under a single distribution.

Each Gaussian distribution is a bell curve, and data points that are most likely under the same bell curve form a cluster together. The peak of the bell curve is at the mean of its data points. A GMM uses K Gaussian distributions to model K clusters.

Clustering using GMM | Source

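Say a dataset contains three different groups of data points, each following a Gaussian distribution, which means there will be three bell curves. GMM estimates the probability that each data point belongs to each of these distributions. The probability density of a GMM with K components is:

$$p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$

where each component is a multivariate Gaussian:

$$\mathcal{N}(x \mid \mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^d \lvert \Sigma \rvert}} \exp\left(-\frac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu)\right)$$

and:

πk is the mixing weight of component k (the weights are non-negative and sum to 1)
μ (μk for component k) is the d-dimensional mean vector
Σ (Σk for component k) is the d×d covariance matrix of the Gaussian
d is the number of features
x is a data point (a d-dimensional vector)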

A GMM fits a generative model that gives a probability distribution for a dataset. To avoid overfitting or underfitting, we have to find the optimal number of distributions by evaluating the model likelihood with cross-validation, or with the Akaike Information Criterion (AIC) or the Bayesian Information Criterion (BIC).
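As a sketch of choosing the number of components with BIC, the example below fits scikit-learn's `GaussianMixture` for several component counts on synthetic data (an assumption; substitute your own features) and keeps the count with the lowest BIC.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# Synthetic data for illustration: 500 points drawn around 3 centers
X, _ = make_blobs(n_samples=500, centers=3, random_state=42)

candidates = range(1, 7)
bics = []
for n in candidates:
    gmm = GaussianMixture(n_components=n, covariance_type="full", random_state=0).fit(X)
    bics.append(gmm.bic(X))  # lower BIC = better trade-off between fit and complexity

best_n = list(candidates)[int(np.argmin(bics))]
print("BIC per number of components:", np.round(bics, 1))
print("Number of components with the lowest BIC:", best_n)
```

`gmm.aic(X)` can be used in the same loop if you prefer AIC, which penalizes extra components less heavily.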

  • GMM uses probability distributions, while K-means uses a distance metric, to decide how to segregate the data points into different clusters.
  • GMM is a soft clustering algorithm, in the sense that each data point is assigned to clusters with some degree of uncertainty. In the image above, some data points have a very high probability of belonging to one specific Gaussian, while others lie between two Gaussians. A data point can therefore belong to more than one cluster, e.g., with a 70/30 split between two clusters.
  • K-means may not perform well when clusters have different sizes, shapes, and densities. For example, K-means may fail on elliptically distributed data, since it is better suited to roughly spherical clusters. In such cases, it makes sense to use GMM.
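The difference between hard and soft assignments can be seen with scikit-learn. This sketch stretches synthetic blobs into elliptical clusters (the transformation matrix is an arbitrary choice), fits both K-means and a GMM, and prints how well each recovers the generating groups. GMM often does better on such data, but check the printed scores rather than taking that for granted; `predict_proba` returns the per-component probabilities that make GMM a soft clustering method.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture

X, y = make_blobs(n_samples=600, centers=3, random_state=170)
X = X @ np.array([[0.6, -0.6], [-0.4, 0.8]])  # stretch the blobs into ellipses

km_labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)
gmm = GaussianMixture(n_components=3, covariance_type="full", random_state=0).fit(X)
gmm_labels = gmm.predict(X)    # hard labels: most probable component per point
probs = gmm.predict_proba(X)   # soft assignments: one probability per component

print("K-means ARI:", adjusted_rand_score(y, km_labels))
print("GMM ARI:    ", adjusted_rand_score(y, gmm_labels))
print("Soft assignment of the first point:", np.round(probs[0], 3))
```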

However, K-means needs much less time to assign data points to groups than GMM needs to fit its Gaussian components. K-means is popular because it is easy to apply and converges quickly. It can be applied without specifying a probabilistic model of how the data was generated. Thanks to its simplicity and efficiency, K-means is also used for image segmentation, where it can give results competitive with much more complex deep neural network approaches.

“K-means is a reasonable default choice, at least until you figure out that the clustering step is your bottleneck in terms of overall performance.” - Analytics Dimag


Applications & use cases

K-means is best suited to numeric, continuous data with a relatively small number of dimensions. It works well in scenarios where you want to group together data points that have no predefined labels. Here are some interesting use cases where K-means can be applied:

  • Customer segmentation

“Customer segmentation is the practice of dividing a company's customers into groups that reflect similarity among customers in each group.” - Optimove

Meeting customers' needs is the starting point of relationship marketing, and it can be improved by recognizing that not all customers are the same and that the same offers might not work for all of them. Segmenting customers by their needs and behaviors can help companies market their products to the right customers. For example, telecom companies have a large number of users, and with market or customer segmentation they can personalize campaigns, incentives, and so on.

  • Fraud detection

The continuous growth of the internet and online services is raising concerns about security. Security threats and fraudulent activities, such as a login to an Instagram account from an unusual city or attempts to hide financial misconduct, are now widespread.

Using techniques such as K-means clustering, one can identify patterns of unusual activity. Points that lie far from every cluster centroid are outliers, and an outlier may indicate that a fraud event has taken place and should be investigated.
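One way to turn K-means into an outlier detector is to measure each point's distance to its nearest centroid and flag the points that lie unusually far away. The sketch below uses made-up two-dimensional activity data with one injected anomaly, and the 99th-percentile cut-off is an assumption you would tune for real data.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Hypothetical normal activity: two groups of 150 events each, plus one unusual event at the end
normal = np.vstack([rng.normal(0, 1, size=(150, 2)), rng.normal(6, 1, size=(150, 2))])
X = np.vstack([normal, [[15.0, 15.0]]])

km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
# transform() returns the distance from every point to every centroid; keep the nearest
dist_to_centroid = km.transform(X).min(axis=1)
# Assumed rule: flag the roughly 1% of points that lie farthest from their centroid
threshold = np.percentile(dist_to_centroid, 99)
flagged = np.where(dist_to_centroid > threshold)[0]
print("Flagged indices:", flagged)
```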

  • Document classification

K-means is known for being efficient on large datasets, which makes it a popular choice for grouping documents. Documents can be clustered into multiple categories based on their topics, content, and tags, if available. The documents are first converted into a vector format; term frequencies are then used to identify common terms, and based on those we can identify similarities between groups of documents.
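A small sketch of document clustering with scikit-learn, assuming TF-IDF weighting (a refinement of the raw term frequencies mentioned above) computed with `TfidfVectorizer`. The six documents are invented for illustration.

```python
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer

docs = [
    "python code compiles",
    "python code runs",
    "code in python",
    "pasta sauce recipe",
    "pasta recipe with sauce",
    "tomato sauce for pasta",
]

# Each document becomes a sparse TF-IDF vector (one column per term)
X = TfidfVectorizer().fit_transform(docs)
labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
for label, doc in zip(labels, docs):
    print(label, doc)
```

`TfidfVectorizer` normalizes each vector to unit length by default, so Euclidean distances between documents behave like cosine distances, which is usually what you want for text.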

  • Geospatial analytics

“Outdoor ambient acoustical environments may be predicted through machine learning using geospatial features as inputs. However, collecting sufficient training data is an expensive process, particularly when attempting to improve the accuracy of models based on supervised learning methods over large, geospatially diverse regions.” - Geospatial Model

Because of these limitations of supervised algorithms, we can use unsupervised algorithms such as K-means clustering instead and compare geodiversity by clustering the data.

  • Image segmentation

Using K-means, we can find patterns in image pixels, which allows faster and more efficient processing. Each pixel is compared with the cluster centroids and mapped to the cluster with the nearest centroid. In the final output, similar pixels are grouped together in the same cluster.
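Image segmentation with K-means usually treats each pixel's color as a data point. This sketch builds a synthetic 40×40 RGB image (a red half and a blue half plus noise) instead of loading a real file, clusters the pixel colors, and replaces each pixel with its centroid color.

```python
import numpy as np
from sklearn.cluster import KMeans

# Synthetic stand-in for a real image: left half red, right half blue, plus noise
rng = np.random.default_rng(0)
image = np.zeros((40, 40, 3))
image[:, :20] = [0.9, 0.1, 0.1]
image[:, 20:] = [0.1, 0.1, 0.9]
image = np.clip(image + rng.normal(0, 0.05, image.shape), 0, 1)

h, w, c = image.shape
pixels = image.reshape(-1, c)  # one row per pixel, one column per color channel
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(pixels)

# Replace each pixel with the color of its cluster centroid
segmented = km.cluster_centers_[km.labels_].reshape(h, w, c)
segment_map = km.labels_.reshape(h, w)  # cluster index per pixel
print(np.unique(segment_map[:, :20]), np.unique(segment_map[:, 20:]))
```

For a real photo, load it into an (h, w, 3) array with an image library of your choice and increase `n_clusters` to keep more colors.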

You can find out more about K-means applications and use cases here

Advantages of K-means clustering algorithm

  • Relatively easy to understand and implement.
  • Scalable to large datasets.
  • Relatively low computational cost.
  • Easy to warm-start from known assignments and centroid positions.
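Warm-starting means passing known centroid positions as the starting point instead of a random initialization. In scikit-learn, this can be done by giving `KMeans` an array as `init`; the sketch below assumes that new data arrives as a slightly shifted copy of synthetic training data.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, random_state=42)
# Fit once with the default k-means++ initialization
first = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# Hypothetical new data: the same points shifted slightly
X_new = X + 0.1
# Warm start from the previous centroids; n_init=1 because a fixed start makes restarts pointless
warm = KMeans(n_clusters=3, init=first.cluster_centers_, n_init=1).fit(X_new)
print("Iterations (cold start, warm start):", first.n_iter_, warm.n_iter_)
```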

Disadvantages of K-means clustering algorithm

  • K must be chosen manually, and the results depend on the initial centroid values
  • Results can be inconsistent across different values of K
  • Tends to find spherical (circular in 2D) clusters
  • Centroids get dragged by outliers in the dataset
  • Curse of dimensionality: K-means becomes less effective as the number of dimensions increases

Final thoughts

In this article, we discussed one of the most popular clustering algorithms. We first went through the overview of k-means and how it works, later we followed the same steps to implement it from scratch and via sklearn. We also looked at various metrics and challenges associated with it and its alternatives.

We also saw that K-means is easy to understand and can deliver results quickly. However, its performance can be compromised by slight variations in the data. K-means assumes that clusters are spherical and evenly sized, which may reduce its accuracy when that assumption does not hold. If you want to learn more about K-means clustering, I recommend checking out Stanford CS221 - Kmeans, the Data Science Handbook, and the scikit-learn site.
