Skip to content

Commit 30619ff

Browse files
committed
Merge pull request scikit-learn#3772 from MechCoder/manhattan_metric
[MRG+2] ENH: Patches Nearest Centroid for metric=manhattan for sparse and dense data
2 parents 535d1f6 + e871972 commit 30619ff

File tree

6 files changed

+142
-33
lines changed

6 files changed

+142
-33
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ Bug fixes
119119
estimator. It allows for instance to make bagging of a pipeline object.
120120
By `Arnaud Joly`_
121121

122+
- :class:`neighbors.NearestCentroid` now uses the median as the centroid
123+
when metric is set to ``manhattan``. It was using the mean before.
124+
By `Manoj Kumar`_
125+
122126
API changes summary
123127
-------------------
124128

sklearn/neighbors/nearest_centroid.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#
99
# License: BSD 3 clause
1010

11+
import warnings
1112
import numpy as np
1213
from scipy import sparse as sp
1314

@@ -16,6 +17,7 @@
1617
from ..metrics.pairwise import pairwise_distances
1718
from ..preprocessing import LabelEncoder
1819
from ..utils.validation import check_array, check_X_y
20+
from ..utils.sparsefuncs import csc_median_axis_0
1921

2022

2123
class NearestCentroid(BaseEstimator, ClassifierMixin):
@@ -31,6 +33,12 @@ class NearestCentroid(BaseEstimator, ClassifierMixin):
3133
feature array. If metric is a string or callable, it must be one of
3234
the options allowed by metrics.pairwise.pairwise_distances for its
3335
metric parameter.
36+
The centroid for the samples corresponding to each class is the point
37+
from which the sum of the distances (according to the metric) of all
38+
samples that belong to that particular class is minimized.
39+
If the "manhattan" metric is provided, this centroid is the median and
40+
for all other metrics, the centroid is set to be the mean.
41+
3442
shrink_threshold : float, optional (default = None)
3543
Threshold for shrinking centroids to remove features.
3644
@@ -86,8 +94,14 @@ def fit(self, X, y):
8694
y : array, shape = [n_samples]
8795
Target values (integers)
8896
"""
89-
X, y = check_X_y(X, y, ['csr', 'csc'])
90-
if sp.issparse(X) and self.shrink_threshold:
97+
# If X is sparse and the metric is "manhattan", store it in csc format,
98+
# since that makes it easier to calculate the median.
99+
if self.metric == 'manhattan':
100+
X, y = check_X_y(X, y, ['csc'])
101+
else:
102+
X, y = check_X_y(X, y, ['csr', 'csc'])
103+
is_X_sparse = sp.issparse(X)
104+
if is_X_sparse and self.shrink_threshold:
91105
raise ValueError("threshold shrinking not supported"
92106
" for sparse input")
93107

@@ -107,9 +121,23 @@ def fit(self, X, y):
107121
for cur_class in y_ind:
108122
center_mask = y_ind == cur_class
109123
nk[cur_class] = np.sum(center_mask)
110-
if sp.issparse(X):
124+
if is_X_sparse:
111125
center_mask = np.where(center_mask)[0]
112-
self.centroids_[cur_class] = X[center_mask].mean(axis=0)
126+
127+
# XXX: Update other averaging methods according to the metrics.
128+
if self.metric == "manhattan":
129+
# NumPy does not calculate median of sparse matrices.
130+
if not is_X_sparse:
131+
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
132+
else:
133+
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
134+
else:
135+
if self.metric != 'euclidean':
136+
warnings.warn("Averaging for metrics other than "
137+
"euclidean and manhattan not supported. "
138+
"The average is set to be the mean."
139+
)
140+
self.centroids_[cur_class] = X[center_mask].mean(axis=0)
113141

114142
if self.shrink_threshold:
115143
dataset_centroid_ = np.mean(X, axis=0)

sklearn/neighbors/tests/test_nearest_centroid.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,17 @@ def test_predict_translated_data():
125125
assert_array_equal(y_init, y_translate)
126126

127127

128+
def test_manhattan_metric():
    """Check centroids computed with metric='manhattan'.

    Fitting on the dense ``X`` and on the sparse ``X_csr`` must give the
    same ``centroids_``, and those must equal ``[[-1, -1], [1, 1]]``.
    """
    estimator = NearestCentroid(metric='manhattan')

    # Dense input first; keep a handle on the resulting centroids.
    estimator.fit(X, y)
    centroids_dense = estimator.centroids_

    # Refit the same estimator on the sparse version of the data.
    estimator.fit(X_csr, y)
    centroids_sparse = estimator.centroids_

    assert_array_equal(centroids_sparse, centroids_dense)
    assert_array_equal(centroids_dense, [[-1, -1], [1, 1]])
137+
138+
128139
if __name__ == "__main__":
129140
import nose
130141
nose.runmodule()

sklearn/preprocessing/imputation.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ..utils import check_array
1313
from ..utils import as_float_array
1414
from ..utils.fixes import astype
15+
from ..utils.sparsefuncs import _get_median
1516

1617
from ..externals import six
1718

@@ -31,34 +32,6 @@ def _get_mask(X, value_to_mask):
3132
return X == value_to_mask
3233

3334

34-
def _get_median(data, n_zeros):
35-
"""Compute the median of data with n_zeros additional zeros.
36-
37-
This function is used to support sparse matrices; it modifies data in-place
38-
"""
39-
n_elems = len(data) + n_zeros
40-
if not n_elems:
41-
return np.nan
42-
n_negative = np.count_nonzero(data < 0)
43-
middle, is_odd = divmod(n_elems, 2)
44-
data.sort()
45-
46-
if is_odd:
47-
return _get_elem_at_rank(middle, data, n_negative, n_zeros)
48-
49-
return (_get_elem_at_rank(middle - 1, data, n_negative, n_zeros) +
50-
_get_elem_at_rank(middle, data, n_negative, n_zeros)) / 2.
51-
52-
53-
def _get_elem_at_rank(rank, data, n_negative, n_zeros):
54-
"""Find the value in data augmented with n_zeros for the given rank"""
55-
if rank < n_negative:
56-
return data[rank]
57-
if rank - n_negative < n_zeros:
58-
return 0
59-
return data[rank - n_zeros]
60-
61-
6235
def _most_frequent(array, extra_value, n_repeat):
6336
"""Compute the most frequent value in a 1d array extended with
6437
[extra_value] * n_repeat, where extra_value is assumed to be not part

sklearn/utils/sparsefuncs.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,63 @@ def count_nonzero(X, axis=None, sample_weight=None):
342342
weights=weights)
343343
else:
344344
raise ValueError('Unsupported axis: {0}'.format(axis))
345+
346+
347+
def _get_median(data, n_zeros):
348+
"""Compute the median of data with n_zeros additional zeros.
349+
350+
This function is used to support sparse matrices; it modifies data in-place
351+
"""
352+
n_elems = len(data) + n_zeros
353+
if not n_elems:
354+
return np.nan
355+
n_negative = np.count_nonzero(data < 0)
356+
middle, is_odd = divmod(n_elems, 2)
357+
data.sort()
358+
359+
if is_odd:
360+
return _get_elem_at_rank(middle, data, n_negative, n_zeros)
361+
362+
return (_get_elem_at_rank(middle - 1, data, n_negative, n_zeros) +
363+
_get_elem_at_rank(middle, data, n_negative, n_zeros)) / 2.
364+
365+
366+
def _get_elem_at_rank(rank, data, n_negative, n_zeros):
367+
"""Find the value in data augmented with n_zeros for the given rank"""
368+
if rank < n_negative:
369+
return data[rank]
370+
if rank - n_negative < n_zeros:
371+
return 0
372+
return data[rank - n_zeros]
373+
374+
375+
def csc_median_axis_0(X):
    """Find the median across axis 0 of a CSC matrix.

    It is equivalent to doing np.median(X, axis=0).

    Parameters
    ----------
    X : CSC sparse matrix, shape (n_samples, n_features)
        Input data.

    Returns
    -------
    median : ndarray, shape (n_features,)
        Median.

    Raises
    ------
    TypeError
        If X is not a ``scipy.sparse.csc_matrix``.

    """
    if not isinstance(X, sp.csc_matrix):
        # Inputs that are not sparse matrices (ndarray, list, ...) have no
        # ``format`` attribute; fall back to the type name so the intended
        # TypeError is raised rather than an AttributeError.
        found = getattr(X, "format", type(X).__name__)
        raise TypeError("Expected matrix of CSC format, got %s" % found)

    indptr = X.indptr
    n_samples, n_features = X.shape
    median = np.zeros(n_features)

    # Consecutive indptr entries delimit the stored values of each column.
    for f_ind, (start, end) in enumerate(zip(indptr[:-1], indptr[1:])):

        # _get_median sorts its argument in place, so hand it a copy to
        # prevent modifying X.
        data = np.copy(X.data[start: end])
        # Entries of this column that are not stored count as zeros.
        nz = n_samples - data.size
        median[f_ind] = _get_median(data, nz)

    return median

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
inplace_row_scale,
1111
inplace_swap_row, inplace_swap_column,
1212
min_max_axis,
13-
count_nonzero)
13+
count_nonzero, csc_median_axis_0)
1414
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
1515
from sklearn.utils.testing import assert_raises
1616

@@ -359,3 +359,36 @@ def test_count_nonzero():
359359

360360
assert_raises(TypeError, count_nonzero, X_csc)
361361
assert_raises(ValueError, count_nonzero, X_csr, axis=2)
362+
363+
364+
def test_csc_row_median():
    """Test csc_median_axis_0 actually calculates the median."""

    # Test that it gives the same output as np.median when X is dense
    # (rng.rand draws from [0, 1), so X has essentially no zeros).
    rng = np.random.RandomState(0)
    X = rng.rand(100, 50)
    dense_median = np.median(X, axis=0)
    csc = sp.csc_matrix(X)
    sparse_median = csc_median_axis_0(csc)
    assert_array_equal(sparse_median, dense_median)

    # Test that it gives the same output when X is sparse: entries below
    # 0.7 are zeroed out, and randomly chosen rows are negated so that the
    # data also contains negative values.
    X = rng.rand(51, 100)
    X[X < 0.7] = 0.0
    ind = rng.randint(0, 50, 10)
    X[ind] = -X[ind]
    csc = sp.csc_matrix(X)
    dense_median = np.median(X, axis=0)
    sparse_median = csc_median_axis_0(csc)
    assert_array_equal(sparse_median, dense_median)

    # Test for toy data.
    # Even number of rows: the median averages the two middle values.
    X = [[0, -2], [-1, -1], [1, 0], [2, 1]]
    csc = sp.csc_matrix(X)
    assert_array_equal(csc_median_axis_0(csc), np.array([0.5, -0.5]))
    # Odd number of rows: the median is the middle value.
    X = [[0, -2], [-1, -5], [1, -3]]
    csc = sp.csc_matrix(X)
    assert_array_equal(csc_median_axis_0(csc), np.array([0., -3]))

    # Test that it raises an Error for non-csc matrices.
    assert_raises(TypeError, csc_median_axis_0, sp.csr_matrix(X))

0 commit comments

Comments
 (0)