K-means is one of the simplest unsupervised learning algorithms for clustering problems. Clustering is the process of finding groups of similar objects, so in clustering our goal is to group objects based on the similarity of their features. K-means is easy to understand, easy to implement, and computationally efficient. Let us see how it works.
The basic idea behind K-means is that we define k centroids, one for each cluster. Here, k is a hyperparameter that we must choose before running the algorithm, and the result depends heavily on it, so it is usual to try a range of values to determine the best k. Where do we place the centroids initially? A common choice is to place them as far away from each other as possible. Next, we assign each data point to its nearest centroid. Once every data point has been assigned to one of the centroids, the next step is to recalculate k new centroids. How do we do that? We move each old centroid to the center of the data points that were assigned to it. And how do we find that center? We take the mean of the data points in the cluster. The assignment and update steps are then repeated until the centroids stop moving (or a maximum number of iterations is reached).
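The loop described above can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: it assumes Euclidean distance, uses a "farthest-first" initialization as one concrete reading of "place them as far away from each other as possible", and stops when the centroids no longer move (checked with np.allclose). The helpers kmeans, assign and farthest_first_init are hypothetical names for this example only.

```python
import numpy as np


def assign(X, centroids):
    # Euclidean distance from every point to every centroid, shape (n_points, k)
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    # Assignment step: index of the nearest centroid for each point
    return np.argmin(dists, axis=1)


def farthest_first_init(X, k, rng):
    # Start from one random data point, then repeatedly add the point that is
    # farthest from its nearest already-chosen centroid
    centroids = [X[rng.integers(len(X))]]
    for _ in range(1, k):
        C = np.array(centroids)
        d = np.min(np.linalg.norm(X[:, None, :] - C[None, :, :], axis=2), axis=1)
        centroids.append(X[np.argmax(d)])
    return np.array(centroids, dtype=float)


def kmeans(X, k, max_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    centroids = farthest_first_init(X, k, rng)
    for _ in range(max_iter):
        labels = assign(X, centroids)
        # Update step: move each centroid to the mean of the points assigned to it
        # (an empty cluster keeps its old centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):  # centroids stopped moving
            break
        centroids = new_centroids
    # Final labels are computed against the returned centroids
    return assign(X, centroids), centroids


X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
labels, centroids = kmeans(X, k=2)
print(labels)
print(centroids)
```

On these two well-separated groups the sketch puts the first three points in one cluster and the last three in the other, with centroids at the group means. Farthest-first initialization is sensitive to outliers (an outlier is, by definition, far from everything), which is one motivation for randomized seeding schemes such as k-means++. The broadcasting trick builds an n × k × d array, which is fine for small data but memory-hungry for large inputs.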
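K-means clustering aims to find positions μi, i = 1, 2, ..., k, of the cluster centroids that minimize the sum of squared distances from the data points to the centroid of the cluster they are assigned to. Mathematically, we can write this as:

$$\underset{\mu_1,\dots,\mu_k}{\arg\min} \sum_{i=1}^{k} \sum_{x \in S_i} \lVert x - \mu_i \rVert^2$$

where S_i is the set of data points assigned to cluster i, that is, the points closer to μi than to any other centroid.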
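To make the objective concrete, here is a small worked example. It is a sketch with made-up toy points; kmeans_objective is a hypothetical helper. It evaluates the sum of squared distances for a fixed assignment, once with each centroid at its cluster mean and once with one centroid shifted away from the mean.

```python
import numpy as np


def kmeans_objective(X, labels, centroids):
    # Each point minus the centroid of its own cluster, then sum of squares
    diffs = X - centroids[labels]
    return float(np.sum(diffs ** 2))


X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 0.0], [12.0, 0.0]])
labels = np.array([0, 0, 1, 1])

# Centroids at the cluster means: every point is at distance 1, so J = 4 * 1 = 4
j_mean = kmeans_objective(X, labels, np.array([[1.0, 0.0], [11.0, 0.0]]))
# Moving the first centroid off the mean: 0 + 4 + 1 + 1 = 6
j_shifted = kmeans_objective(X, labels, np.array([[0.0, 0.0], [11.0, 0.0]]))
print(j_mean, j_shifted)  # 4.0 6.0
```

The shifted centroid gives a larger value, which illustrates why the update step uses the mean: for a fixed assignment, the mean of a cluster's points is the position that minimizes that cluster's sum of squared distances.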
If you are curious and want to know more about K-means, check this out.
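```python
from sklearn.datasets import load_digits
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

digits = load_digits()
dataset = digits.data  # 1797 samples, 64 pixel features each

# standardize each feature to zero mean and unit variance
ss = StandardScaler()
dataset = ss.fit_transform(dataset)

# 10 clusters (one per digit), k-means++ seeding, best of 10 runs
model = KMeans(n_clusters=10, init="k-means++", n_init=10)
model.fit(dataset)
```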
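Output (printed by an older scikit-learn release; n_jobs and precompute_distances have since been removed from KMeans):

```
KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=300,
       n_clusters=10, n_init=10, n_jobs=1, precompute_distances='auto',
       random_state=None, tol=0.0001, verbose=0)
```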
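The model above is configured with init="k-means++", a randomized version of the "place centroids far apart" idea: each new centroid is sampled with probability proportional to its squared distance from the nearest centroid already chosen. Below is a minimal sketch of that basic seeding rule, assuming Euclidean distance and at least k distinct points in X; kmeans_plus_plus_init is a hypothetical helper. scikit-learn's own implementation adds a greedy step that tries several candidates per round and keeps the best one, so it will not reproduce these exact choices.

```python
import numpy as np


def kmeans_plus_plus_init(X, k, seed=0):
    """Basic k-means++ seeding: returns k rows of X to use as initial centroids."""
    rng = np.random.default_rng(seed)
    # First centroid: a data point chosen uniformly at random
    centroids = [X[rng.integers(len(X))]]
    for _ in range(1, k):
        C = np.array(centroids)
        # Squared distance from each point to its nearest already-chosen centroid
        d2 = np.min(((X[:, None, :] - C[None, :, :]) ** 2).sum(axis=2), axis=1)
        # Next centroid: sampled with probability proportional to that squared
        # distance, so far-away points are likely (but not certain) to be picked
        idx = rng.choice(len(X), p=d2 / d2.sum())
        centroids.append(X[idx])
    return np.array(centroids)


X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
seeds = kmeans_plus_plus_init(X, k=2)
print(seeds)
```

Already-chosen points have probability zero of being picked again, and the randomness keeps a single outlier from always being selected, unlike a strict "farthest point" rule.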
```python
model.labels_  # cluster label assigned to each data point
```

```
array([0, 9, 9, ..., 9, 3, 3])
```
```python
model.inertia_  # sum of squared distances of samples to their closest centroid
```

inertia_ is a single floating-point number, not an array: it is the within-cluster sum of squares that K-means minimizes, so lower values mean tighter clusters. Because it tends to shrink as k grows, it cannot be used on its own to pick k.
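Since k has to be chosen by trying a range of values, one common heuristic (the "elbow" method, which the original text does not name) is to fit a model for each k and look for the point where inertia_ stops dropping sharply. The sketch below does that on the standardized digits data and also recomputes the sum of squared distances by hand so it can be compared with inertia_. The range 2 to 12 and random_state=0 are arbitrary choices for this example.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

X = StandardScaler().fit_transform(load_digits().data)

inertias = {}
manual = {}
for k in range(2, 13):
    km = KMeans(n_clusters=k, init="k-means++", n_init=10, random_state=0).fit(X)
    inertias[k] = km.inertia_
    # Hand-computed sum of *squared* distances from each sample to its assigned centroid
    manual[k] = float(np.sum((X - km.cluster_centers_[km.labels_]) ** 2))

for k, value in inertias.items():
    print(f"k={k:2d}  inertia={value:.1f}  manual={manual[k]:.1f}")
```

To look for the elbow, plot the values, for example with plt.plot(list(inertias), list(inertias.values()), marker="o"). On real data the elbow is often not sharp, so this is a guide rather than a rule.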
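```python
import matplotlib.pyplot as plt
%matplotlib inline

# show the second image in the dataset (its true label is 1)
plt.imshow(digits.images[1], cmap='gray')
# predict() expects a 2D array, so reshape the single sample to shape (1, 64);
# the result is a cluster index, and cluster indices are arbitrary,
# so it need not equal the digit 1
model.predict(dataset[1].reshape(1, -1))
```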
```python
# let's try another sample: the third image (true label 2)
plt.imshow(digits.images[2], cmap='gray')
model.predict(dataset[2].reshape(1, -1))  # again a cluster index, not necessarily 2
```
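model.predict returns a cluster index, and K-means numbers its clusters arbitrarily, so the index cannot be compared directly with the digit shown. One common way to relate the two (using true labels that K-means never saw) is to map each cluster to the digit that occurs most often among its members. A minimal sketch, assuming majority voting and random_state=0 (choices of this example, not of the original code); cluster_to_digit is a hypothetical name.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

digits = load_digits()
X = StandardScaler().fit_transform(digits.data)
model = KMeans(n_clusters=10, init="k-means++", n_init=10, random_state=0).fit(X)

# For each cluster, find the digit that occurs most often among its members
cluster_to_digit = {}
for c in range(model.n_clusters):
    members = digits.target[model.labels_ == c]
    cluster_to_digit[c] = int(np.bincount(members, minlength=10).argmax())

# Translate every sample's cluster index into a digit guess and score it
predicted_digits = np.array([cluster_to_digit[c] for c in model.labels_])
accuracy = float(np.mean(predicted_digits == digits.target))
print(f"majority-vote accuracy: {accuracy:.3f}")

# Single sample: reshape(1, -1) turns one 64-feature row into a batch of size 1
cluster = int(model.predict(X[1].reshape(1, -1))[0])
print("cluster", cluster, "-> digit guess", cluster_to_digit[cluster])
```

Several clusters can map to the same digit, and some digits may get no cluster at all, so this mapping is an evaluation aid rather than part of the unsupervised algorithm itself.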