Skip to content

Commit f9e0736

Browse files
author
Phil Roth
committed
Adding a new kmeans example.
1 parent 552a80a commit f9e0736

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
====================================
3+
Demonstration of k-means assumptions
4+
====================================
5+
6+
This example is meant to illustrate situations where k-means will produce
7+
unintuitive and possibly unexpected clusters. In the first three plots, the
8+
input data does not conform to some implicit assumption that k-means makes and
9+
undesirable clusters are produced as a result. In the last plot, k-means
10+
returns intuitive clusters despite unevenly sized blobs.
11+
"""
12+
# Echo the module docstring so that running the example prints its description.
print(__doc__)
13+
14+
# Author: Phil Roth <[email protected]>
15+
# License: BSD 3 clause
16+
17+
import numpy as np
18+
import matplotlib.pyplot as plt
19+
20+
from sklearn.cluster import KMeans
21+
from sklearn.datasets import make_blobs
22+
23+
# Shared settings reused by the sections below so the scenarios stay comparable.
n_samples = 1500
random_state = 170

# Base dataset: make_blobs called with only the sample count and the seed.
X, y = make_blobs(random_state=random_state, n_samples=n_samples)

# One square canvas that holds all of the scatter plots.
plt.figure(figsize=(12, 12))
28+
29+
# Incorrect number of clusters: fit k-means with n_clusters=2 on the base data
two_means = KMeans(n_clusters=2, random_state=random_state)
y_pred = two_means.fit_predict(X)

# Top-left panel, points coloured by predicted cluster.
plt.subplot(2, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c=y_pred)
plt.title("Incorrect Number of Blobs")
35+
36+
# Anisotropically distributed data: apply a fixed 2x2 linear transformation
# to the base blobs, then cluster the transformed points.
transformation = [[0.60834549, -0.63667341], [-0.40887718, 0.85253229]]
X_aniso = np.dot(X, transformation)
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_aniso)

# Top-right panel, points coloured by predicted cluster.
plt.subplot(222)
plt.scatter(X_aniso[:, 0], X_aniso[:, 1], c=y_pred)
plt.title("Anisotropically Distributed Blobs")
44+
45+
# Different variance: regenerate the blobs with one standard deviation per
# cluster, using the same sample count and seed as the base data.
blob_spreads = [1.0, 2.5, 0.5]
X_varied, y_varied = make_blobs(
    n_samples=n_samples,
    cluster_std=blob_spreads,
    random_state=random_state,
)
three_means = KMeans(n_clusters=3, random_state=random_state)
y_pred = three_means.fit_predict(X_varied)

# Bottom-left panel, points coloured by predicted cluster.
plt.subplot(2, 2, 3)
plt.scatter(X_varied[:, 0], X_varied[:, 1], c=y_pred)
plt.title("Unequal Variance")
54+
55+
# Unevenly sized blobs: keep at most 500, 100 and 10 of the base points
# labelled 0, 1 and 2 respectively, stacked in that order.
kept_per_label = ((0, 500), (1, 100), (2, 10))
X_filtered = np.vstack(
    [X[y == label][:count] for label, count in kept_per_label]
)
uneven_means = KMeans(n_clusters=3, random_state=random_state)
y_pred = uneven_means.fit_predict(X_filtered)

# Bottom-right panel, points coloured by predicted cluster.
plt.subplot(2, 2, 4)
plt.scatter(X_filtered[:, 0], X_filtered[:, 1], c=y_pred)
plt.title("Unevenly Sized Blobs")

# Render the finished figure.
plt.show()

0 commit comments

Comments
 (0)