Skip to content

Commit c0657ce

Browse files
committed
ENH less copying in validation for neighbors
1 parent 912823f commit c0657ce

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

sklearn/neighbors/base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,7 @@ def _fit(self, X):
173173
self._fit_method = 'kd_tree'
174174
return self
175175

176-
X = safe_asarray(X)
177-
178-
if X.ndim != 2:
179-
raise ValueError("data type not understood")
176+
X = atleast2d_or_csr(X, copy=False)
180177

181178
n_samples = X.shape[0]
182179
if n_samples == 0:
@@ -189,7 +186,7 @@ def _fit(self, X):
189186
if self.effective_metric_ not in VALID_METRICS_SPARSE['brute']:
190187
raise ValueError("metric '%s' not valid for sparse input"
191188
% self.effective_metric_)
192-
self._fit_X = X.tocsr()
189+
self._fit_X = X.copy()
193190
self._tree = None
194191
self._fit_method = 'brute'
195192
return self

sklearn/utils/tests/test_validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ def test_as_float_array():
7070
assert_false(np.isnan(M).any())
7171

7272

73+
def test_atleast2d_or_sparse():
    """Check conversion, copy semantics and dtype handling for sparse input.

    For each sparse format, verifies that:

    * ``atleast2d_or_csr(X, copy=True)`` returns a CSR matrix whose data can
      be overwritten without modifying ``X``;
    * ``atleast2d_or_csc(X, copy=False)`` returns a CSC matrix that shares
      its data with ``X`` when ``X`` is already CSC, and leaves ``X``
      untouched when a format conversion was required;
    * an explicit ``dtype`` argument is honoured.
    """
    # csc_matrix is part of the list so that the "input is already CSC"
    # branch of the copy=False check below is actually exercised.
    sparse_types = [sp.csr_matrix, sp.csc_matrix, sp.dok_matrix,
                    sp.lil_matrix, sp.coo_matrix]
    for typ in sparse_types:
        X = typ(np.arange(9, dtype=float).reshape(3, 3))

        # copy=True: writing to the result must never leak into X.
        Y = atleast2d_or_csr(X, copy=True)
        assert_true(isinstance(Y, sp.csr_matrix))
        Y.data[:] = 1
        assert_array_equal(X.toarray().ravel(), np.arange(9))

        # copy=False: the data is shared only when no conversion happened.
        Y = atleast2d_or_csc(X, copy=False)
        assert_true(isinstance(Y, sp.csc_matrix))
        Y.data[:] = 4
        if isinstance(X, sp.csc_matrix):
            assert_true(np.all(X.data == 4))
        else:
            assert_true(np.all(X.toarray().ravel() == np.arange(9)))

        Y = atleast2d_or_csr(X, dtype=np.float32)
        assert_true(Y.dtype == np.float32)
90+
91+
7392
def test_check_arrays_exceptions():
7493
"""Check that invalid arguments raise appropriate exceptions"""
7594
assert_raises(ValueError, check_arrays, [0], [0, 1])

sklearn/utils/validation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ def array2d(X, dtype=None, order=None, copy=False, force_all_finite=True):
117117

118118

119119
def _atleast2d_or_sparse(X, dtype, order, copy, sparse_class, convmethod,
120-
force_all_finite):
120+
check_same_type, force_all_finite):
121121
if sparse.issparse(X):
122-
if dtype is None or X.dtype == dtype:
122+
if check_same_type(X) and X.dtype == dtype:
123+
X = getattr(X, convmethod)(copy=copy)
124+
elif dtype is None or X.dtype == dtype:
123125
X = getattr(X, convmethod)()
124126
else:
125127
X = sparse_class(X, dtype=dtype)
@@ -139,7 +141,8 @@ def atleast2d_or_csc(X, dtype=None, order=None, copy=False,
139141
Also, converts np.matrix to np.ndarray.
140142
"""
141143
return _atleast2d_or_sparse(X, dtype, order, copy, sparse.csc_matrix,
142-
"tocsc", force_all_finite)
144+
"tocsc", sparse.isspmatrix_csc,
145+
force_all_finite)
143146

144147

145148
def atleast2d_or_csr(X, dtype=None, order=None, copy=False,
@@ -149,7 +152,8 @@ def atleast2d_or_csr(X, dtype=None, order=None, copy=False,
149152
Also, converts np.matrix to np.ndarray.
150153
"""
151154
return _atleast2d_or_sparse(X, dtype, order, copy, sparse.csr_matrix,
152-
"tocsr", force_all_finite)
155+
"tocsr", sparse.isspmatrix_csr,
156+
force_all_finite)
153157

154158

155159
def _num_samples(x):

0 commit comments

Comments
 (0)