Clustering with K-Means

Clustering algorithms are useful when we want to find structure in our data with little to no prior knowledge of what that structure might look like. The result of a simple clustering algorithm is a partition of the data into a set of 'clusters', groups of samples that are more closely related to each other than to samples in other groups. K-means is one of the simplest and most widely used algorithms for this kind of analysis. It requires that you specify K (the number of clusters) up front. It then alternates between assigning each sample to the nearest cluster mean (centroid) in Euclidean space and recomputing each mean from its assigned samples, until the assignments stop changing. The algorithm always converges, but only to a local optimum that depends on how the centroids are initialized, so different runs can produce different clusterings. Even so, it is simple and often reliable enough to provide insight.
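
To make the two alternating steps concrete, here is a minimal NumPy sketch of the K-means loop. It is illustrative only and is not scikit-learn's implementation. The function name `kmeans_sketch`, its parameters, the choice to initialize centroids by sampling data points, and the toy data are all assumptions made for this example.

```python
import numpy as np


def kmeans_sketch(X, k, n_iter=100, seed=0):
    """Minimal K-means loop (hypothetical helper, for illustration only)."""
    rng = np.random.default_rng(seed)
    # Assumption: initialize centroids by sampling k distinct data points.
    centroids = X[rng.choice(len(X), size=k, replace=False)]

    def assign(centroids):
        # Assignment step: index of the nearest centroid (Euclidean distance).
        distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        return distances.argmin(axis=1)

    for _ in range(n_iter):
        labels = assign(centroids)
        # Update step: move each centroid to the mean of its assigned samples;
        # an empty cluster keeps its previous centroid.
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break  # centroids stopped moving
        centroids = new_centroids

    # Final assignment so labels are consistent with the returned centroids.
    return assign(centroids), centroids


# Toy data: two well-separated blobs in 2-D (made up for illustration).
rng = np.random.default_rng(42)
X_toy = np.vstack([rng.normal(0.0, 0.5, size=(50, 2)),
                   rng.normal(5.0, 0.5, size=(50, 2))])
toy_labels, toy_centroids = kmeans_sketch(X_toy, k=2)
print(toy_centroids)
```

Changing `seed` changes the starting centroids. On harder data than this toy example, that can lead to a different final clustering, which is why library implementations typically run several initializations and keep the best one.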

Of course, we can use scikit-learn to run k-means in Python.

# example modified from http://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_iris.html
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # also registers the '3d' projection on older Matplotlib versions

from sklearn.cluster import KMeans
from sklearn import datasets


iris = datasets.load_iris()  # load data
X = iris.data
y = iris.target

# set up a 3-D plot
fig = plt.figure(0, figsize=(4, 3))
plt.clf()
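ax = fig.add_axes([0, 0, .95, 1], projection='3d', elev=48, azim=134)  # Axes3D(fig, ...) is no longer attached to the figure automatically in recent Matplotlib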
plt.cla()

model = KMeans(n_clusters=3, n_init=10, random_state=0)  # choose K = 3; fix the seed so repeated runs give the same clustering
model.fit(X)  # fit the data
labels = model.labels_  # cluster labels of each sample according to KMeans

ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=labels.astype(float))  # scatter plot colored by cluster membership (np.float was removed from NumPy)

ax.set_xlabel('Petal width')  # label x axis
ax.set_ylabel('Sepal length')  # label y axis
ax.set_zlabel('Petal length')  # label z axis

plt.show()

"""
We can also plot the ground truth clustering based on y (iris.target)
"""

# Set up the 3d plot again
fig = plt.figure(1, figsize=(4, 3))
plt.clf()
ax = fig.add_axes([0, 0, .95, 1], projection='3d', elev=48, azim=134)
plt.cla()

# Plot each sample colored based on its true label
for name, label in [('Setosa', 0),
                    ('Versicolour', 1),
                    ('Virginica', 2)]:
    ax.text3D(X[y == label, 3].mean(),
              X[y == label, 0].mean() + 1.5,
              X[y == label, 2].mean(), name,
              horizontalalignment='center',
              bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))

y = np.choose(y, [1, 2, 0]).astype(float)  # Reorder the labels so the colors line up with the cluster plot (cluster numbering is arbitrary, so the match may vary between runs)
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y)  # plot

ax.set_xlabel('Petal width')  # label x axis
ax.set_ylabel('Sepal length')  # label y axis
ax.set_zlabel('Petal length')  # label z axis

plt.show()
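
The visual comparison between the cluster plot and the ground-truth plot can also be summarized numerically. The sketch below is one possible way to do this, not part of the original example. It tabulates how the true species fall into the K-means clusters and computes the adjusted Rand index with `sklearn.metrics.adjusted_rand_score`. The variable names (`X_iris`, `y_true`, `km`, `table`) and the choice of seeds are assumptions made for this illustration.

```python
import numpy as np
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

iris = datasets.load_iris()
X_iris, y_true = iris.data, iris.target  # reload: y was reordered above for plotting

for seed in range(3):
    # n_init=1 keeps a single random start per run, so seed-to-seed differences stay visible
    km = KMeans(n_clusters=3, n_init=1, random_state=seed).fit(X_iris)

    # rows: true species, columns: cluster index assigned by K-means
    table = np.zeros((3, 3), dtype=int)
    np.add.at(table, (y_true, km.labels_), 1)

    # adjusted Rand index: 1.0 means identical partitions, values near 0 mean chance-level agreement
    ari = adjusted_rand_score(y_true, km.labels_)
    print(f"seed={seed}  inertia={km.inertia_:.2f}  ARI={ari:.3f}")
    print(table)
```

Because cluster numbering is arbitrary, the columns of the table may be permuted from one seed to the next. The adjusted Rand index does not depend on that numbering, which makes it convenient for comparing runs.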
