@@ -15,7 +15,7 @@ cimport numpy as np
1515cimport cython
1616
1717from ..utils.extmath import norm
18- from sklearn.utils.sparsefuncs_fast cimport add_row_csr
18+ from sklearn.utils.sparsefuncs_fast import assign_rows_csr
1919from sklearn.utils.fixes import bincount
2020
2121ctypedef np.float64_t DOUBLE
@@ -326,9 +326,8 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
326326 centers: array, shape (n_clusters, n_features)
327327 The resulting centers
328328 """
329- n_features = X.shape[1 ]
330-
331- cdef np.npy_intp cluster_id
329+ cdef int n_features = X.shape[1 ]
330+ cdef int curr_label
332331
333332 cdef np.ndarray[DOUBLE, ndim= 1 ] data = X.data
334333 cdef np.ndarray[int , ndim= 1 ] indices = X.indices
@@ -341,24 +340,25 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
341340 bincount(labels, minlength = n_clusters)
342341 cdef np.ndarray[np.npy_intp, ndim= 1 , mode= " c" ] empty_clusters = \
343342 np.where(n_samples_in_cluster == 0 )[0 ]
343+ cdef int n_empty_clusters = empty_clusters.shape[0 ]
344344
345345 # maybe also relocate small clusters?
346346
347- if empty_clusters.shape[ 0 ] > 0 :
347+ if n_empty_clusters > 0 :
348348 # find points to reassign empty clusters to
349- far_from_centers = distances.argsort()[::- 1 ]
349+ far_from_centers = distances.argsort()[::- 1 ][:n_empty_clusters]
350350
351- for i in range (empty_clusters.shape[ 0 ]):
352- cluster_id = empty_clusters[i]
351+ # XXX two relocated clusters could be close to each other
352+ assign_rows_csr(X, far_from_centers, empty_clusters, centers)
353353
354- # XXX two relocated clusters could be close to each other
355- centers[cluster_id] = 0.
356- add_row_csr(data, indices, indptr, far_from_centers[i],
357- centers[cluster_id])
358- n_samples_in_cluster[cluster_id] = 1
354+ for i in range (n_empty_clusters):
355+ n_samples_in_cluster[empty_clusters[i]] = 1
359356
360357 for i in range (labels.shape[0 ]):
361- add_row_csr(data, indices, indptr, i, centers[labels[i]])
358+ curr_label = labels[i]
359+ for ind in range (indptr[i], indptr[i + 1 ]):
360+ j = indices[ind]
361+ centers[curr_label, j] += data[ind]
362362
363363 centers /= n_samples_in_cluster[:, np.newaxis]
364364
0 commit comments