88#
99# License: BSD 3 clause
1010
11+ import warnings
1112import numpy as np
1213from scipy import sparse as sp
1314
1617from ..metrics .pairwise import pairwise_distances
1718from ..preprocessing import LabelEncoder
1819from ..utils .validation import check_array , check_X_y
20+ from ..utils .sparsefuncs import csc_median_axis_0
1921
2022
2123class NearestCentroid (BaseEstimator , ClassifierMixin ):
@@ -31,6 +33,12 @@ class NearestCentroid(BaseEstimator, ClassifierMixin):
3133 feature array. If metric is a string or callable, it must be one of
3234 the options allowed by metrics.pairwise.pairwise_distances for its
3335 metric parameter.
36+ The centroid for the samples corresponding to each class is the point
37+ from which the sum of the distances (according to the metric) of all
38+ samples that belong to that particular class is minimized.
39+ If the "manhattan" metric is provided, this centroid is the median;
40+ for all other metrics, the centroid is set to be the mean.
41+
3442 shrink_threshold : float, optional (default = None)
3543 Threshold for shrinking centroids to remove features.
3644
@@ -86,8 +94,14 @@ def fit(self, X, y):
8694 y : array, shape = [n_samples]
8795 Target values (integers)
8896 """
89- X , y = check_X_y (X , y , ['csr' , 'csc' ])
90- if sp .issparse (X ) and self .shrink_threshold :
97+ # If X is sparse and the metric is "manhattan", store it in csc
98+ # format, which makes it easier to calculate the median.
99+ if self .metric == 'manhattan' :
100+ X , y = check_X_y (X , y , ['csc' ])
101+ else :
102+ X , y = check_X_y (X , y , ['csr' , 'csc' ])
103+ is_X_sparse = sp .issparse (X )
104+ if is_X_sparse and self .shrink_threshold :
91105 raise ValueError ("threshold shrinking not supported"
92106 " for sparse input" )
93107
@@ -107,9 +121,23 @@ def fit(self, X, y):
107121 for cur_class in y_ind :
108122 center_mask = y_ind == cur_class
109123 nk [cur_class ] = np .sum (center_mask )
110- if sp . issparse ( X ) :
124+ if is_X_sparse :
111125 center_mask = np .where (center_mask )[0 ]
112- self .centroids_ [cur_class ] = X [center_mask ].mean (axis = 0 )
126+
127+ # XXX: Update other averaging methods according to the metrics.
128+ if self .metric == "manhattan" :
129+ # NumPy does not calculate median of sparse matrices.
130+ if not is_X_sparse :
131+ self .centroids_ [cur_class ] = np .median (X [center_mask ], axis = 0 )
132+ else :
133+ self .centroids_ [cur_class ] = csc_median_axis_0 (X [center_mask ])
134+ else :
135+ if self .metric != 'euclidean' :
136+ warnings .warn ("Averaging for metrics other than "
137+ "euclidean and manhattan not supported. "
138+ "The average is set to be the mean."
139+ )
140+ self .centroids_ [cur_class ] = X [center_mask ].mean (axis = 0 )
113141
114142 if self .shrink_threshold :
115143 dataset_centroid_ = np .mean (X , axis = 0 )
0 commit comments