Skip to content

Commit b5b1c59

Browse files
committed
DOC Add silhouette analysis plot for KMeans
1 parent a413f87 commit b5b1c59

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
===============================================================================
3+
Silhouette analysis for sample data clustered using KMeans clustering algorithm
4+
===============================================================================
5+
6+
Silhouette analysis can be used to study the separation distance between the
7+
resulting clusters. The silhouette plot displays a measure of how close each
8+
point in one cluster is to points in the neighboring clusters. This measure has
9+
a range of [-1, 1]. Silhoette coefficients (as these values are referred to as)
10+
near +1 indicate that the sample is far away from the neighboring clusters.
11+
A value of 0 indicates that the sample is on or very close to the decision
12+
boundary between two neighboring clusters and negative values (upto -1)
13+
indicate that those samples might have been assigned to the wrong cluster.
14+
"""
15+
16+
from __future__ import print_function
17+
18+
from sklearn.datasets import make_blobs
19+
from sklearn.cluster import KMeans
20+
from sklearn.metrics import silhouette_samples, silhouette_score
21+
22+
import matplotlib.pyplot as plt
23+
import matplotlib.cm as cm
24+
import numpy as np
25+
26+
print(__doc__)
27+
28+
# Generating the sample data from make_blobs
29+
# This particular setting has one distict cluster and 3 clusters placed close
30+
# together.
31+
X, y = make_blobs(n_samples=500,
32+
n_features=2,
33+
centers=4,
34+
cluster_std=1.0,
35+
center_box=(-10.0, 10.0),
36+
shuffle=True,
37+
random_state=1) # For reproducibility
38+
39+
range_n_clusters = [ 2, 4, 6 ]
40+
41+
for n_clusters in range_n_clusters:
42+
# Create a subplot with 1 row and 2 columns
43+
fig, (ax1, ax2) = plt.subplots(1, 2)
44+
fig.set_size_inches(18, 7)
45+
46+
# The 1st subplot is the silhouette plot
47+
# The silhouette coefficient can range from -1, 1 but in this example all
48+
# lie within [-0.1, 1]
49+
ax1.set_xlim([-0.1, 1])
50+
# The n_clusters*10 are the additional samples to demarcate the space
51+
# between silhouette plots of individual clusters.
52+
ax1.set_ylim([0, len(X) + n_clusters*10])
53+
54+
# Initialize the clusterer with n_clusters value and a random generator
55+
# seed of 10 for reproducibility.
56+
clusterer = KMeans(n_clusters=n_clusters, random_state=10)
57+
cluster_labels = clusterer.fit_predict(X)
58+
59+
# The silhouette_score gives the average value for all the samples.
60+
# This gives a perspective into the density and separation of the formed
61+
# clusters
62+
print("For n_clusters = %d," % n_clusters,
63+
"The average silhouette_score is :",
64+
silhouette_score(X, cluster_labels))
65+
66+
# Compute the silhouette scores for each sample
67+
sample_silhouette_values = silhouette_samples(X, cluster_labels)
68+
69+
# This will hold the silhouette coefficient of all the clusters separated
70+
# by 0 samples.
71+
sorted_clustered_sample_silhouette_values = []
72+
73+
for i in np.unique(cluster_labels):
74+
ith_cluster_silhouette_values = \
75+
sample_silhouette_values[cluster_labels == i]
76+
77+
# Add the ith_cluster_silhouette_values after sorting them
78+
ith_cluster_silhouette_values.sort()
79+
80+
# The introduced 0 samples are to differentiate clearly between the
81+
# different clusters
82+
sorted_clustered_sample_silhouette_values += \
83+
ith_cluster_silhouette_values.tolist() + [0]*10
84+
85+
# Plot and fill all the sihouette plot polygons
86+
ax1.plot(sorted_clustered_sample_silhouette_values,
87+
range(len(X) + 10*n_clusters))
88+
ax1.fill_between(sorted_clustered_sample_silhouette_values,
89+
range(len(X) + 10*n_clusters))
90+
ax1.set_title("The silhouette plot for the various clusters.")
91+
ax1.set_xlabel("The silhouette coefficient values")
92+
ax1.set_ylabel("Cluster label")
93+
# A vertical line at x = 0.
94+
ax1.axvline()
95+
96+
ax1.set_yticks([]) # Clear the yaxis labels / ticks
97+
ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
98+
99+
# Computing custom label coordinates for labeling the clusters
100+
101+
offset = 0
102+
for i in range(n_clusters):
103+
size_cluster_i = sum(cluster_labels == i)
104+
105+
x = -0.05
106+
# Label them at the middle
107+
y = offset + (size_cluster_i / 2.0)
108+
ax1.text(x, y, str(i))
109+
110+
# Compute the base offset for next plot
111+
offset += size_cluster_i + 10 # 10 for the 0 samples
112+
113+
# 2nd Plot showing the actual clusters formed
114+
for k in range(len(X)):
115+
color = cm.spectral(float(cluster_labels[k]) / n_clusters, 1)
116+
ax2.scatter(X[k, 0], X[k, 1], marker='.', color=color)
117+
118+
# Label the cluster centers with the cluster number for identification and
119+
# study of the corresponding silhouette plot.
120+
for c in clusterer.cluster_centers_:
121+
# Use the clusterer to know to which cluster number the current center
122+
# c belongs to
123+
i = clusterer.predict(c)[0]
124+
ax2.scatter(c[0], c[1], marker='o', c="white", alpha = 1, s = 200)
125+
ax2.scatter(c[0], c[1], marker="$%d$" % i, alpha = 1, s = 50)
126+
127+
ax2.set_title("The visualization of the clustered data.")
128+
plt.suptitle(("Silhouette analysis for KMeans clustering on sample data "
129+
"with n_clusters = %d" % n_clusters),
130+
fontsize = 14, fontweight = 'bold')
131+
ax2.set_xlabel("Feature space for the 1st feature")
132+
ax2.set_ylabel("Feature space for the 2nd feature")
133+
plt.show()

0 commit comments

Comments
 (0)