Skip to content

Commit 6009c05

Browse files
committed
ENH: more control on reassignment in MiniBatchKMeans
1 parent 5d97a11 commit 6009c05

File tree

1 file changed

+47
-7
lines changed

1 file changed

+47
-7
lines changed

sklearn/cluster/k_means_.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,8 @@ def score(self, X):
815815
def _mini_batch_step(X, x_squared_norms, centers, counts,
816816
old_center_buffer, compute_squared_diff,
817817
distances=None, random_reassign=False,
818-
random_state=None):
818+
random_state=None, reassignment_ratio=.01,
819+
verbose=False):
819820
"""Incremental update of the centers for the Minibatch K-Means algorithm
820821
821822
Parameters
@@ -837,6 +838,26 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
837838
distances: array, dtype float64, shape (n_samples), optional
838839
If not None, should be a pre-allocated array that will be used to store
839840
the distances of each sample to its closest center.
841+
842+
random_state: integer or numpy.RandomState, optional
843+
The generator used to initialize the centers. If an integer is
844+
given, it fixes the seed. Defaults to the global numpy random
845+
number generator.
846+
847+
random_reassign: boolean, optional
848+
If True, centers with very low counts are
849+
randomly-reassigned to observations in dense areas.
850+
851+
reassignment_ratio: float, optional
852+
Control the fraction of the maximum number of counts for a
853+
center to be reassigned. A higher value means that low count
854+
centers are more easily reassigned, which means that the
855+
model will take longer to converge, but should converge to a
856+
better clustering.
857+
858+
verbose: bool, optional
859+
Controls the verbosity
860+
840861
"""
841862
# Perform label assignment to nearest centers
842863
nearest_center, inertia = _labels_inertia(X, x_squared_norms, centers,
@@ -845,7 +866,7 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
845866
random_state = check_random_state(random_state)
846867
# Reassign clusters that have very low counts
847868
to_reassign = np.logical_or((counts <= 1),
848-
counts <= .001 * counts.max())
869+
counts <= reassignment_ratio * counts.max())
849870
# Pick new clusters amongst observations with a probability
850871
# proportional to their closeness to their center
851872
distance_to_centers = np.asarray(centers[nearest_center] - X)
@@ -859,6 +880,11 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
859880
new_centers = np.searchsorted(distance_to_centers.cumsum(),
860881
rand_vals)
861882
new_centers = X[new_centers]
883+
if verbose:
884+
n_reassigns = to_reassign.sum()
885+
if n_reassigns:
886+
print("[_mini_batch_step] Reassigning %i cluster centers."
887+
% n_reassigns)
862888
centers[to_reassign] = new_centers
863889

864890
# implementation for the sparse CSR representation completely written in
@@ -1030,6 +1056,14 @@ class MiniBatchKMeans(KMeans):
10301056
given, it fixes the seed. Defaults to the global numpy random
10311057
number generator.
10321058
1059+
reassignment_ratio: float, optional
1060+
Control the fraction of the maximum number of counts for a
1061+
center to be reassigned. A higher value means that low count
1062+
centers are more easily reassigned, which means that the
1063+
model will take longer to converge, but should converge to a
1064+
better clustering.
1065+
1066+
10331067
Attributes
10341068
----------
10351069
@@ -1053,7 +1087,8 @@ class MiniBatchKMeans(KMeans):
10531087
def __init__(self, n_clusters=8, init='k-means++', max_iter=100,
10541088
batch_size=100, verbose=0, compute_labels=True,
10551089
random_state=None, tol=0.0, max_no_improvement=10,
1056-
init_size=None, n_init=3, k=None):
1090+
init_size=None, n_init=3, k=None,
1091+
reassignment_ratio=0.01):
10571092

10581093
super(MiniBatchKMeans, self).__init__(
10591094
n_clusters=n_clusters, init=init, max_iter=max_iter,
@@ -1064,6 +1099,7 @@ def __init__(self, n_clusters=8, init='k-means++', max_iter=100,
10641099
self.batch_size = batch_size
10651100
self.compute_labels = compute_labels
10661101
self.init_size = init_size
1102+
self.reassignment_ratio = reassignment_ratio
10671103

10681104
def fit(self, X, y=None):
10691105
"""Compute the centroids on X by chunking it into mini-batches.
@@ -1143,7 +1179,7 @@ def fit(self, X, y=None):
11431179
batch_inertia, centers_squared_diff = _mini_batch_step(
11441180
X_valid, x_squared_norms[validation_indices],
11451181
cluster_centers, counts, old_center_buffer, False,
1146-
distances=distances)
1182+
distances=distances, verbose=self.verbose)
11471183

11481184
# Keep only the best cluster centers across independent inits on
11491185
# the common validation set
@@ -1175,7 +1211,9 @@ def fit(self, X, y=None):
11751211
old_center_buffer, tol > 0.0, distances=distances,
11761212
random_reassign=(iteration_idx + 1) % (10 +
11771213
self.counts_.min()) == 0,
1178-
random_state=self.random_state)
1214+
random_state=self.random_state,
1215+
reassignment_ratio=self.reassignment_ratio,
1216+
verbose=self.verbose)
11791217

11801218
# Monitor convergence and do early stopping if necessary
11811219
if _mini_batch_convergence(
@@ -1224,14 +1262,16 @@ def partial_fit(self, X, y=None):
12241262
else:
12251263
# The lower the minimum count is, the more we do random
12261264
# reassignment, however, we don't want to do random
1227-
# reassignement to often, to allow for building up counts
1265+
# reassignment too often, to allow for building up counts
12281266
random_reassign = self.random_state.randint(10 * (1 +
12291267
self.counts_.min())) == 0
12301268

12311269
_mini_batch_step(X, x_squared_norms, self.cluster_centers_,
12321270
self.counts_, np.zeros(0, np.double), 0,
12331271
random_reassign=random_reassign,
1234-
random_state=self.random_state)
1272+
random_state=self.random_state,
1273+
reassignment_ratio=self.reassignment_ratio,
1274+
verbose=self.verbose)
12351275

12361276
if self.compute_labels:
12371277
self.labels_, self.inertia_ = _labels_inertia(

0 commit comments

Comments
 (0)