Skip to content

Commit 0706636

Browse files
committed
Merge pull request scikit-learn#5063 from amueller/bagging_input_validation
[MRG] test for accepted sparse matrix types
2 parents 98fc670 + 434ee95 commit 0706636

File tree

5 files changed

+36
-31
lines changed

5 files changed

+36
-31
lines changed

sklearn/ensemble/bagging.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def fit(self, X, y, sample_weight=None):
251251
random_state = check_random_state(self.random_state)
252252

253253
# Convert data
254-
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'])
254+
X, y = check_X_y(X, y, ['csr', 'csc'])
255255

256256
# Remap output
257257
n_samples, self.n_features_ = X.shape
@@ -587,7 +587,7 @@ def predict_proba(self, X):
587587
"""
588588
check_is_fitted(self, "classes_")
589589
# Check data
590-
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
590+
X = check_array(X, accept_sparse=['csr', 'csc'])
591591

592592
if self.n_features_ != X.shape[1]:
593593
raise ValueError("Number of features of the model must "
@@ -865,7 +865,7 @@ def predict(self, X):
865865
"""
866866
check_is_fitted(self, "estimators_features_")
867867
# Check data
868-
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
868+
X = check_array(X, accept_sparse=['csr', 'csc'])
869869

870870
# Parallel loop
871871
n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators,

sklearn/feature_selection/univariate_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def fit(self, X, y):
319319
self : object
320320
Returns self.
321321
"""
322-
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'])
322+
X, y = check_X_y(X, y, ['csr', 'csc'])
323323

324324
if not callable(self.score_func):
325325
raise TypeError("The score function should be a callable, %s (%s) "

sklearn/linear_model/coordinate_descent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,8 @@ def fit(self, X, y):
10741074
# by the model fitting itself
10751075
X = check_array(X, 'csc', copy=False)
10761076
if sparse.isspmatrix(X):
1077-
if not np.may_share_memory(reference_to_old_X.data, X.data):
1077+
if (hasattr(reference_to_old_X, "data") and
1078+
not np.may_share_memory(reference_to_old_X.data, X.data)):
10781079
# X is a sparse matrix and has been copied
10791080
copy_X = False
10801081
elif not np.may_share_memory(reference_to_old_X, X):

sklearn/linear_model/randomized_l1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def fit(self, X, y):
8888
self : object
8989
Returns an instance of self.
9090
"""
91-
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], y_numeric=True)
91+
X, y = check_X_y(X, y, ['csr', 'csc'], y_numeric=True)
9292
X = as_float_array(X, copy=False)
9393
n_samples, n_features = X.shape
9494

sklearn/utils/estimator_checks.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -269,34 +269,38 @@ def check_estimator_sparse_data(name, Estimator):
269269
rng = np.random.RandomState(0)
270270
X = rng.rand(40, 10)
271271
X[X < .8] = 0
272-
X = sparse.csr_matrix(X)
272+
X_csr = sparse.csr_matrix(X)
273273
y = (4 * rng.rand(40)).astype(np.int)
274-
# catch deprecation warnings
275-
with warnings.catch_warnings():
276-
if name in ['Scaler', 'StandardScaler']:
277-
estimator = Estimator(with_mean=False)
278-
else:
279-
estimator = Estimator()
280-
set_fast_parameters(estimator)
281-
# fit and predict
282-
try:
283-
estimator.fit(X, y)
284-
if hasattr(estimator, "predict"):
285-
estimator.predict(X)
286-
if hasattr(estimator, 'predict_proba'):
287-
estimator.predict_proba(X)
288-
except TypeError as e:
289-
if 'sparse' not in repr(e):
274+
for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']:
275+
X = X_csr.asformat(sparse_format)
276+
# catch deprecation warnings
277+
with warnings.catch_warnings():
278+
if name in ['Scaler', 'StandardScaler']:
279+
estimator = Estimator(with_mean=False)
280+
else:
281+
estimator = Estimator()
282+
set_fast_parameters(estimator)
283+
# fit and predict
284+
try:
285+
estimator.fit(X, y)
286+
if hasattr(estimator, "predict"):
287+
pred = estimator.predict(X)
288+
assert_equal(pred.shape, (X.shape[0],))
289+
if hasattr(estimator, 'predict_proba'):
290+
probs = estimator.predict_proba(X)
291+
assert_equal(probs.shape, (X.shape[0], 4))
292+
except TypeError as e:
293+
if 'sparse' not in repr(e):
294+
print("Estimator %s doesn't seem to fail gracefully on "
295+
"sparse data: error message state explicitly that "
296+
"sparse input is not supported if this is not the case."
297+
% name)
298+
raise
299+
except Exception:
290300
print("Estimator %s doesn't seem to fail gracefully on "
291-
"sparse data: error message state explicitly that "
292-
"sparse input is not supported if this is not the case."
293-
% name)
301+
"sparse data: it should raise a TypeError if sparse input "
302+
"is explicitly not supported." % name)
294303
raise
295-
except Exception:
296-
print("Estimator %s doesn't seem to fail gracefully on "
297-
"sparse data: it should raise a TypeError if sparse input "
298-
"is explicitly not supported." % name)
299-
raise
300304

301305

302306
def check_dtype_object(name, Estimator):

0 commit comments

Comments
 (0)