A tutorial covering K-Means Clustering, complete with code and interactive visualizations. Made by Saurav Maheshkar using Weights & Biases.

Unsupervised Learning uses models to look for patterns in datasets with little to no human supervision, typically without any labels. This is in clear contrast to Supervised Learning, which uses labels along with the data points to learn a mapping from the input space to the target space.

Clustering is one of the most common applications of Unsupervised Learning: it groups similar data points together into clusters. It is mostly used as a data exploration technique, and the resulting cluster assignments can also be reused as features for downstream supervised tasks such as regression.

One of the most widely used clustering methods is the K-Means algorithm. Let's go over it!

We'll start with the famous Iris dataset. Say you have a set of features for each flower, such as its petal and sepal lengths and widths, and you want to sort the flowers into groups. We might not know the exact number of groups in advance, but we have some notion of similarity, such as a distance metric, in mind. K-Means is a simple yet efficient approach to this problem.

Specifically, this dataset consists of 150 samples from 3 different types of irises (Setosa, Versicolour, and Virginica), with four features per sample: sepal length, sepal width, petal length, and petal width.
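To follow along, the dataset can be loaded directly from scikit-learn. This is a minimal sketch assuming scikit-learn is installed; note that scikit-learn spells the second species `versicolor`. The species labels `y` are only kept for later comparison, since K-Means itself never looks at them.

```python
from sklearn.datasets import load_iris

# load_iris ships with scikit-learn, so no download is needed
iris = load_iris()
X = iris.data    # feature matrix: one row per flower, measurements in cm
y = iris.target  # species index per flower (0, 1, 2); unused by K-Means

print(X.shape)                     # (150, 4)
print(iris.feature_names)          # sepal length/width, petal length/width
print(iris.target_names.tolist())  # ['setosa', 'versicolor', 'virginica']
```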

The K-Means algorithm starts by choosing K random data points and setting them as the initial cluster centers (also known as centroids). Then, for every data point, we calculate the distance from that point to each of the cluster centers and assign it to the "closest" cluster. After one pass through the entire dataset, every point belongs to exactly one cluster.
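The initialization and assignment steps can be sketched in a few lines of NumPy. This is an illustrative sketch, not the scikit-learn implementation: the helper names `init_centroids` and `assign_clusters` are hypothetical, and it assumes Euclidean distance.

```python
import numpy as np

def init_centroids(X, k, rng):
    """Pick k distinct data points at random to serve as the initial centroids."""
    idx = rng.choice(len(X), size=k, replace=False)
    return X[idx].astype(float)

def assign_clusters(X, centroids):
    """Return the index of the nearest centroid (Euclidean distance) for every row of X."""
    # distances[i, j] = ||X[i] - centroids[j]||, computed with broadcasting
    distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return distances.argmin(axis=1)

# Two well-separated toy groups of 2-D points
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])

rng = np.random.default_rng(0)
centroids = init_centroids(X, k=2, rng=rng)
labels = assign_clusters(X, centroids)  # one cluster index per point

# With hand-picked centers, the assignment is easy to verify
print(assign_clusters(X, np.array([[0.0, 0.5], [10.0, 10.5]])))  # [0 0 1 1]
```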

In the next step, we replace each center with the mean of all points currently assigned to its cluster, and then reassign every point to its nearest updated center. We repeat these two steps until the centers stop changing, at which point the algorithm has converged to a "fit" (a local optimum, which is not necessarily the best possible one).
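Putting the assignment and update steps together gives the classic iterative procedure, often called Lloyd's algorithm. Below is a minimal NumPy sketch under a few assumptions of ours: Euclidean distance, a hypothetical `kmeans` helper with a `max_iter` cap as a safety net, and leaving a centroid in place if its cluster ever ends up empty (one of several reasonable ways to handle that case).

```python
import numpy as np

def nearest_centroid(X, centroids):
    """Index of the closest centroid (Euclidean) for every row of X."""
    distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return distances.argmin(axis=1)

def kmeans(X, k, max_iter=100, seed=0):
    """Plain K-Means; returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    # Initialization: k distinct random data points become the first centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(max_iter):
        # Assignment step: every point joins its nearest centroid
        labels = nearest_centroid(X, centroids)
        # Update step: move each centroid to the mean of its assigned points
        new_centroids = centroids.copy()
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:  # leave the centroid in place if its cluster is empty
                new_centroids[j] = members.mean(axis=0)
        # Convergence check: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment so the labels match the returned centroids
    return centroids, nearest_centroid(X, centroids)

X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, labels = kmeans(X, k=2)
print(centroids[np.argsort(centroids[:, 0])])  # rows [0, 0.5] and [10, 10.5]
```

Notice that the converged centroids, such as `[0, 0.5]`, are means rather than points from `X`.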

Notice that, because each center is the mean of the values in its cluster, the centroids are generally no longer actual data points but rather points in the same feature "space".

So far we've covered the basics of K-Means clustering, but we haven't discussed how to tell whether our model has given us the best possible fit. Because the initial centers are chosen at random, the algorithm can converge to non-optimal cluster centers. That's why we usually run the whole procedure several times with different random initializations and keep the run with the best "fit".
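One way to sketch this restart strategy is with scikit-learn's `KMeans`: fit it several times from different random initializations and keep the run with the lowest `inertia_` (the fit metric defined next). Here `init="random"` mirrors the random-point initialization described above (scikit-learn's default is the smarter `k-means++` seeding), and the choice of 10 seeds is arbitrary. In practice, the `n_init` parameter performs these restarts for you and keeps the best run in terms of inertia.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X = load_iris().data

# Ten independent runs, each starting from a different random set of centers
runs = [
    KMeans(n_clusters=3, init="random", n_init=1, random_state=seed).fit(X)
    for seed in range(10)
]

# Keep the run with the lowest inertia (within-cluster sum of squares)
best = min(runs, key=lambda km: km.inertia_)
print("inertia per run:", [round(float(km.inertia_), 2) for km in runs])
print("best inertia:", round(float(best.inertia_), 2))

# Equivalent built-in shortcut: n_init restarts, best run kept automatically
auto = KMeans(n_clusters=3, init="random", n_init=10, random_state=0).fit(X)
```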

But what constitutes a good fit? One way to measure it is to add up the variation within every cluster and aim to minimize this total. This metric is commonly known as the inertia or within-cluster sum of squares. Each run then gets a single score, and we keep the clustering with the lowest total within-cluster variation. Mathematically, the inertia can be written as:

$$\sum_{i=1}^{n} \min_{\mu_j \in \mathcal{C}} \lVert x_i - \mu_j \rVert^2$$

where the sum runs over all $n$ data points $x_i$, $\mathcal{C}$ is the set of cluster centroids, and each point contributes the squared distance to its nearest centroid $\mu_j$.
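To connect the formula to code, here is a small NumPy sketch of the inertia computation (the `inertia` helper is our own illustration, not a library function). It assumes Euclidean distance, as in the formula, and checks the result against scikit-learn's `inertia_` attribute on Iris, which reports the same sum of squared distances to the closest centroid.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

def inertia(X, centroids):
    """Sum over all points of the squared distance to the nearest centroid."""
    # sq_dists[i, j] = ||x_i - mu_j||^2
    sq_dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return sq_dists.min(axis=1).sum()  # min over centroids, then sum over points

# Toy check by hand: squared distances to the nearest center are 1, 1 and 0
toy_X = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0]])
toy_centroids = np.array([[0.0, 1.0], [10.0, 0.0]])
toy = inertia(toy_X, toy_centroids)  # 2.0

# On Iris, the manual value should agree with scikit-learn's inertia_
X = load_iris().data
km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
print(inertia(X, km.cluster_centers_), km.inertia_)
```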

Let's go over a code sample showing how you can use `sklearn.cluster` together with `wandb` to easily visualize the clusters.

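```python
# Import
import wandb
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

# Load the data and start a W&B run holding the settings ("kmeans-iris" is a placeholder project name)
iris = load_iris()
X = iris.data
run = wandb.init(project="kmeans-iris", config={"n_clusters": 3, "seed": 42, "labels": iris.target_names.tolist()})
config = run.config
# Define the Estimator
est = KMeans(n_clusters=config.n_clusters, random_state=config.seed)
# Compute the Clusters
est.fit(X)
# Plot the Clusters to W&B (labels_ holds the cluster index of each point, so no second fit is needed)
wandb.sklearn.plot_clusterer(est, X, cluster_labels=est.labels_, labels=config.labels, model_name="KMeans")
```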

And that wraps up our short post on K-Means Clustering and how you can use `KMeans` from sklearn on an example dataset. To see the full suite of wandb features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments below or on our forum!

Check out these other reports on Fully Connected covering other fundamental concepts like Linear Regression and Decision Trees.
