Skip to content

Commit 91450c6

Browse files
committed
FIX: Wrap csc_row_median around the _get_median imputer function
1 parent 6f41bd1 commit 91450c6

File tree

3 files changed

+16
-47
lines changed

3 files changed

+16
-47
lines changed

sklearn/neighbors/nearest_centroid.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def fit(self, X, y):
8888
Target values (integers)
8989
"""
9090
X, y = check_X_y(X, y, ['csc'])
91-
X_sparse = sp.issparse(X)
92-
if X_sparse and self.shrink_threshold:
91+
is_X_sparse = sp.issparse(X)
92+
if is_X_sparse and self.shrink_threshold:
9393
raise ValueError("threshold shrinking not supported"
9494
" for sparse input")
9595

@@ -109,13 +109,13 @@ def fit(self, X, y):
109109
for cur_class in y_ind:
110110
center_mask = y_ind == cur_class
111111
nk[cur_class] = np.sum(center_mask)
112-
if X_sparse:
112+
if is_X_sparse:
113113
center_mask = np.where(center_mask)[0]
114114

115115
# XXX: Update other averaging methods according to the metrics.
116116
if self.metric == "manhattan":
117117
# NumPy does not calculate median of sparse matrices.
118-
if not X_sparse:
118+
if not is_X_sparse:
119119
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
120120
else:
121121
self.centroids_[cur_class] = csc_row_median(X[center_mask])

sklearn/utils/sparsefuncs.py

Lines changed: 9 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ def count_nonzero(X, axis=None, sample_weight=None):
348348
def csc_row_median(csc):
349349
"""
350350
Find the median across axis 0 of a CSC matrix.
351-
Equivalent to doing np.median(X, axis=0)
351+
Wrapper around the _get_median function in imputer and is equivalent
352+
to doing np.median(X, axis=0)
352353
353354
Parameters
354355
----------
@@ -361,55 +362,20 @@ def csc_row_median(csc):
361362
Median.
362363
363364
"""
365+
from ..preprocessing.imputation import _get_median
366+
364367
if not isinstance(csc, sp.csc_matrix):
365-
warnings.warn("Non CSC matix passed. Will convert to CSC format.")
366-
csc = sp.csc_matrix(csc)
368+
raise TypeError("Expected matrix of CSC format, got %s" % csc.format)
367369

368370
indptr = csc.indptr
369371
n_samples, n_features = csc.shape
370-
371-
# Highly likely that the median of a sparse matrix is zero.
372-
# Remains zero if the if/else conditions are not checked below.
373372
median = np.zeros(n_features)
374373

375374
for f_ind, ptr in enumerate(indptr[:-1]):
376-
sorted_nonzero = np.sort(csc.data[ptr: indptr[f_ind + 1]])
377-
nz = n_samples - sorted_nonzero.size
378-
zero_ind = np.searchsorted(sorted_nonzero, 0)
379-
neg_idx = sorted_nonzero[: zero_ind]
380-
pos_idx = sorted_nonzero[zero_ind: ]
381-
odd = n_samples % 2
382-
mid_ind = n_samples // 2
383-
384-
if odd:
385-
# Number of negative terms is greater than (n_features + 1) / 2
386-
# which implies the median is negative.
387-
if zero_ind > mid_ind:
388-
median[f_ind] = sorted_nonzero[mid_ind]
389-
390-
# The sum of the negative terms and the number of zeros is less
391-
# than the (n_features + 1) / 2 which implies the median is positive.
392-
elif zero_ind + nz <= mid_ind:
393-
median[f_ind] = pos_idx[mid_ind - nz - neg_idx.size]
394375

395-
else:
396-
# The first two conditions are highly unlikely.
397-
# When the n_features / 2 is the last negative term and
398-
# (n_features / 2) + 1 is zero.
399-
if zero_ind == mid_ind:
400-
median[f_ind] = neg_idx[-1] / 2.
401-
402-
# When the n_features / 2 is zero and (n_features / 2) + 1
403-
# is the first positive term.
404-
elif neg_idx.size + nz == mid_ind:
405-
median[f_ind] = pos_idx[0] / 2.
406-
407-
# Same comments as for the odd case.
408-
elif zero_ind > mid_ind:
409-
median[f_ind] = (sorted_nonzero[mid_ind - 1] +
410-
sorted_nonzero[mid_ind]) / 2.
411-
elif zero_ind + nz < mid_ind:
412-
npz = mid_ind - nz - neg_idx.size
413-
median[f_ind] = (pos_idx[npz - 1] + pos_idx[npz]) / 2.
376+
# Prevent modifying csc in place
377+
data = np.copy(csc.data[ptr: indptr[f_ind + 1]])
378+
nz = n_samples - data.size
379+
median[f_ind] = _get_median(data, nz)
414380

415381
return median

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,6 @@ def test_csc_row_median():
389389
X = [[0, -2], [-1, -5], [1, -3]]
390390
csc = sp.csc_matrix(X)
391391
assert_array_equal(csc_row_median(csc), np.array([0., -3]))
392+
393+
# Test that it raises an Error for non-csc matrices.
394+
assert_raises(TypeError, csc_row_median, sp.csr_matrix(X))

0 commit comments

Comments
 (0)