Skip to content

Commit 0bcbf92

Browse files
committed
Merge pull request scikit-learn#4965 from mrphilroth/4922
[MRG+1] New k-means example
2 parents f9ae0f8 + 253ae0e commit 0bcbf92

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

doc/modules/clustering.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ It suffers from various drawbacks:
156156
prior to k-means clustering can alleviate this problem
157157
and speed up the computations.
158158

159+
.. image:: ../auto_examples/cluster/images/plot_kmeans_assumptions_001.png
160+
:target: ../auto_examples/cluster/plot_kmeans_assumptions.html
161+
:align: center
162+
:scale: 50
163+
159164
K-means is often referred to as Lloyd's algorithm. In basic terms, the
160165
algorithm has three steps. The first step chooses the initial centroids, with
161166
the most basic method being to choose :math:`k` samples from the dataset
@@ -213,6 +218,8 @@ transform method of a trained model of :class:`KMeans`.
213218

214219
.. topic:: Examples:
215220

221+
* :ref:`example_cluster_plot_kmeans_assumptions.py`: Demonstrating when
222+
k-means performs intuitively and when it does not
216223
* :ref:`example_cluster_plot_kmeans_digits.py`: Clustering handwritten digits
217224

218225
.. topic:: References:
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
====================================
3+
Demonstration of k-means assumptions
4+
====================================
5+
6+
This example is meant to illustrate situations where k-means will produce
7+
unintuitive and possibly unexpected clusters. In the first three plots, the
8+
input data does not conform to some implicit assumption that k-means makes and
9+
undesirable clusters are produced as a result. In the last plot, k-means
10+
returns intuitive clusters despite unevenly sized blobs.
11+
"""
12+
print(__doc__)
13+
14+
# Author: Phil Roth <[email protected]>
15+
# License: BSD 3 clause
16+
17+
import numpy as np
18+
import matplotlib.pyplot as plt
19+
20+
from sklearn.cluster import KMeans
21+
from sklearn.datasets import make_blobs
22+
23+
plt.figure(figsize=(12, 12))
24+
25+
n_samples = 1500
26+
random_state = 170
27+
X, y = make_blobs(n_samples=n_samples, random_state=random_state)
28+
29+
# Incorrect number of clusters
30+
y_pred = KMeans(n_clusters=2, random_state=random_state).fit_predict(X)
31+
32+
plt.subplot(221)
33+
plt.scatter(X[:, 0], X[:, 1], c=y_pred)
34+
plt.title("Incorrect Number of Blobs")
35+
36+
# Anisotropicly distributed data
37+
transformation = [[ 0.60834549, -0.63667341], [-0.40887718, 0.85253229]]
38+
X_aniso = np.dot(X, transformation)
39+
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_aniso)
40+
41+
plt.subplot(222)
42+
plt.scatter(X_aniso[:, 0], X_aniso[:, 1], c=y_pred)
43+
plt.title("Anisotropicly Distributed Blobs")
44+
45+
# Different variance
46+
X_varied, y_varied = make_blobs(n_samples=n_samples,
47+
cluster_std=[1.0, 2.5, 0.5],
48+
random_state=random_state)
49+
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_varied)
50+
51+
plt.subplot(223)
52+
plt.scatter(X_varied[:, 0], X_varied[:, 1], c=y_pred)
53+
plt.title("Unequal Variance")
54+
55+
# Unevenly sized blobs
56+
X_filtered = np.vstack((X[y == 0][:500], X[y == 1][:100], X[y == 2][:10]))
57+
y_pred = KMeans(n_clusters=3, random_state=random_state).fit_predict(X_filtered)
58+
59+
plt.subplot(224)
60+
plt.scatter(X_filtered[:, 0], X_filtered[:, 1], c=y_pred)
61+
plt.title("Unevenly Sized Blobs")
62+
63+
plt.show()

0 commit comments

Comments
 (0)