Skip to content

Commit c83adb5

Browse files
committed
BUG: check n_clusters == len(cluster_centers_)
1 parent 134df77 commit c83adb5

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

sklearn/cluster/k_means_.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,12 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
550550

551551
if sp.issparse(centers):
552552
centers = centers.toarray()
553+
554+
if len(centers) != k:
555+
raise ValueError('The shape of the inital centers (%s) '
556+
'does not match the number of clusters %i'
557+
% (centers.shape, k))
558+
553559
return centers
554560

555561

@@ -842,8 +848,8 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
842848
counts <= .001 * counts.max())
843849
# Pick new clusters amongst observations with a probability
844850
# proportional to their closeness to their center
845-
distance_to_centers = (centers[nearest_center] - X)
846-
distance_to_centers **=2
851+
distance_to_centers = np.asarray(centers[nearest_center] - X)
852+
distance_to_centers **= 2
847853
distance_to_centers = distance_to_centers.sum(axis=1)
848854
# Flip the ordering of the distances
849855
distance_to_centers -= distance_to_centers.max()

sklearn/cluster/tests/test_k_means.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,15 @@ def test_sparse_mb_k_means_callable_init():
321321
def test_init(X, k, random_state):
322322
return centers
323323

324-
mb_k_means = MiniBatchKMeans(init=test_init, random_state=42).fit(X_csr)
324+
# Small test to check that giving the wrong number of centers
325+
# raises a meaningful error
326+
assert_raises(ValueError,
327+
MiniBatchKMeans(init=test_init, random_state=42).fit,
328+
X_csr)
329+
330+
# Now check that the fit actually works
331+
mb_k_means = MiniBatchKMeans(n_clusters=3, init=test_init,
332+
random_state=42).fit(X_csr)
325333
_check_fitted_model(mb_k_means)
326334

327335

0 commit comments

Comments
 (0)