Skip to content

Commit 444bf22

Browse files
committed
DOC ENH Simplify the example code; Add plots for n_clusters = 3 and 5
1 parent ae95a0c commit 444bf22

File tree

1 file changed

+60
-59
lines changed

1 file changed

+60
-59
lines changed
Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
"""
22
===============================================================================
3-
Silhouette analysis for sample data clustered using KMeans clustering algorithm
3+
Selecting the number of clusters with silhouette analysis on KMeans clustering
44
===============================================================================
55
66
Silhouette analysis can be used to study the separation distance between the
77
resulting clusters. The silhouette plot displays a measure of how close each
8-
point in one cluster is to points in the neighboring clusters. This measure has
9-
a range of [-1, 1]. Silhoette coefficients (as these values are referred to as)
10-
near +1 indicate that the sample is far away from the neighboring clusters.
11-
A value of 0 indicates that the sample is on or very close to the decision
12-
boundary between two neighboring clusters and negative values (upto -1)
13-
indicate that those samples might have been assigned to the wrong cluster.
8+
point in one cluster is to points in the neighboring clusters and thus provides
9+
a way to assess parameters like number of clusters visually. This measure has a
10+
range of [-1, 1].
11+
12+
Silhoette coefficients (as these values are referred to as) near +1 indicate
13+
that the sample is far away from the neighboring clusters. A value of 0
14+
indicates that the sample is on or very close to the decision boundary between
15+
two neighboring clusters and negative values indicate that those samples might
16+
have been assigned to the wrong cluster.
17+
18+
In this example the silhouette analysis is used to choose an optimal value for
19+
``n_clusters``. The silhouette plot shows that the ``n_clusters`` value of 3, 5
20+
and 6 are a bad pick for the given data due to the presence of clusters with
21+
below average silhouette scores and also due to wide fluctuations in the size
22+
of the silhouette plots. Silhouette analysis is more ambivalent in deciding
23+
between 2 and 4.
24+
25+
Also from the thickness of the silhouette plot the cluster size can be
26+
visualized. The silhouette plot for cluster 0 when ``n_clusters`` is equal to
27+
2, is bigger in size owing to the grouping of the 3 sub clusters into one big
28+
cluster. However when the ``n_clusters`` is equal to 4, all the plots are more
29+
or less of similar thickness and hence are of similar sizes as can be also
30+
verified from the labelled scatter plot on the right.
1431
"""
1532

1633
from __future__ import print_function
@@ -31,12 +48,12 @@
3148
X, y = make_blobs(n_samples=500,
3249
n_features=2,
3350
centers=4,
34-
cluster_std=1.0,
51+
cluster_std=1,
3552
center_box=(-10.0, 10.0),
3653
shuffle=True,
3754
random_state=1) # For reproducibility
3855

39-
range_n_clusters = [2, 4, 6]
56+
range_n_clusters = [2, 3, 4, 5, 6]
4057

4158
for n_clusters in range_n_clusters:
4259
# Create a subplot with 1 row and 2 columns
@@ -47,9 +64,9 @@
4764
# The silhouette coefficient can range from -1, 1 but in this example all
4865
# lie within [-0.1, 1]
4966
ax1.set_xlim([-0.1, 1])
50-
# The n_clusters*10 are the additional samples to demarcate the space
51-
# between silhouette plots of individual clusters.
52-
ax1.set_ylim([0, len(X) + n_clusters * 10])
67+
# The (n_clusters+1)*10 is for inserting blank space between silhouette
68+
# plots of individual clusters, to demarcate them clearly.
69+
ax1.set_ylim([0, len(X) + (n_clusters + 1) * 10])
5370

5471
# Initialize the clusterer with n_clusters value and a random generator
5572
# seed of 10 for reproducibility.
@@ -66,74 +83,58 @@
6683
# Compute the silhouette scores for each sample
6784
sample_silhouette_values = silhouette_samples(X, cluster_labels)
6885

69-
# This will hold the silhouette coefficient of all the clusters separated
70-
# by 0 samples.
71-
sorted_clustered_sample_silhouette_values = []
72-
73-
for i in np.unique(cluster_labels):
86+
y_lower = 10
87+
for i in range(n_clusters):
88+
# Aggregate the silhouette scores for samples belonging to
89+
# cluster i, and sort them
7490
ith_cluster_silhouette_values = \
7591
sample_silhouette_values[cluster_labels == i]
7692

77-
# Add the ith_cluster_silhouette_values after sorting them
7893
ith_cluster_silhouette_values.sort()
7994

80-
# The introduced 0 samples are to differentiate clearly between the
81-
# different clusters
82-
sorted_clustered_sample_silhouette_values += \
83-
ith_cluster_silhouette_values.tolist() + [0] * 10
95+
size_cluster_i = ith_cluster_silhouette_values.shape[0]
96+
y_upper = y_lower + size_cluster_i
8497

85-
x_values = np.array(sorted_clustered_sample_silhouette_values)
86-
y_range = np.arange(len(X) + 10 * n_clusters)
87-
88-
# Computing custom label coordinates for labeling the clusters
89-
# Plot the silhouette with the corresponding cluster color
90-
offset = 0
91-
for i in range(n_clusters):
92-
size_cluster_i = sum(cluster_labels == i)
98+
color = cm.spectral(float(i) / n_clusters)
99+
ax1.fill_betweenx(np.arange(y_lower, y_upper),
100+
0, ith_cluster_silhouette_values,
101+
facecolor=color, edgecolor=color, alpha=0.7)
93102

94-
x = -0.05
95-
# Label them at the middle
96-
dy = size_cluster_i
97-
ax1.text(x, offset + 0.5 * dy, str(i))
103+
# Label the silhouette plots with their cluster numbers at the middle
104+
ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
98105

99-
y_bottom = offset
100-
y_top = offset + dy
101-
102-
color = cm.spectral(float(i) / n_clusters, 1)
103-
ax1.fill_betweenx(y_range, 0, x_values,
104-
where=((y_range >= y_bottom) & (y_range < y_top)),
105-
facecolor=color, edgecolor=color)
106-
# Compute the base offset for next plot
107-
offset += size_cluster_i + 10 # 10 for the 0 samples
106+
# Compute the new y_lower for next plot
107+
y_lower = y_upper + 10 # 10 for the 0 samples
108108

109109
ax1.set_title("The silhouette plot for the various clusters.")
110110
ax1.set_xlabel("The silhouette coefficient values")
111111
ax1.set_ylabel("Cluster label")
112-
112+
113113
# The vertical line for average silhoutte score of all the values
114-
ax1.axvline(x = silhouette_avg, color = "red", linestyle = "--")
114+
ax1.axvline(x=silhouette_avg, color="red", linestyle="--")
115115

116116
ax1.set_yticks([]) # Clear the yaxis labels / ticks
117117
ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
118118

119119
# 2nd Plot showing the actual clusters formed
120-
for k in range(len(X)):
121-
color = cm.spectral(float(cluster_labels[k]) / n_clusters, 1)
122-
ax2.scatter(X[k, 0], X[k, 1], marker='.', color=color)
123-
124-
# Label the cluster centers with the cluster number for identification and
125-
# study of the corresponding silhouette plot.
126-
for c in clusterer.cluster_centers_:
127-
# Use the clusterer to know to which cluster number the current center
128-
# c belongs to
129-
i = clusterer.predict(c)[0]
130-
ax2.scatter(c[0], c[1], marker='o', c="white", alpha=1, s=200)
131-
ax2.scatter(c[0], c[1], marker="$%d$" % i, alpha=1, s=50)
120+
ax2.scatter(X[:, 0], X[:, 1], marker='.', s=30, lw=0, alpha=0.7,
121+
c=map(cm.spectral, cluster_labels.astype(float) / n_clusters))
122+
123+
# Labeling the clusters
124+
centers = clusterer.cluster_centers_
125+
# Draw white circles at cluster centers
126+
ax2.scatter(centers[:, 0], centers[:, 1],
127+
marker='o', c="white", alpha=1, s=200)
128+
129+
for i, c in enumerate(centers):
130+
ax2.scatter(c[0], c[1], marker='$%d$' % i, alpha=1, s=50)
132131

133132
ax2.set_title("The visualization of the clustered data.")
133+
ax2.set_xlabel("Feature space for the 1st feature")
134+
ax2.set_ylabel("Feature space for the 2nd feature")
135+
134136
plt.suptitle(("Silhouette analysis for KMeans clustering on sample data "
135137
"with n_clusters = %d" % n_clusters),
136138
fontsize=14, fontweight='bold')
137-
ax2.set_xlabel("Feature space for the 1st feature")
138-
ax2.set_ylabel("Feature space for the 2nd feature")
139+
139140
plt.show()

0 commit comments

Comments
 (0)