Skip to content

Commit 814223c

Browse files
jnothmanMechCoder
authored andcommitted
[MRG+1] FIX support memmap scalars as CV scores (scikit-learn#6789)
* FIX support memmap scalars as CV scores * FIX test for Python 3.5 and NumPy 1.12
1 parent d12bc46 commit 814223c

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

sklearn/cross_validation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,13 @@ def _score(estimator, X_test, y_test, scorer):
16481648
score = scorer(estimator, X_test)
16491649
else:
16501650
score = scorer(estimator, X_test, y_test)
1651+
if hasattr(score, 'item'):
1652+
try:
1653+
# e.g. unwrap memmapped scalars
1654+
score = score.item()
1655+
except ValueError:
1656+
# non-scalar?
1657+
pass
16511658
if not isinstance(score, numbers.Number):
16521659
raise ValueError("scoring must return a number, got %s (%s) instead."
16531660
% (str(score), type(score)))

sklearn/model_selection/_validation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ def _score(estimator, X_test, y_test, scorer):
301301
score = scorer(estimator, X_test)
302302
else:
303303
score = scorer(estimator, X_test, y_test)
304+
if hasattr(score, 'item'):
305+
try:
306+
# e.g. unwrap memmapped scalars
307+
score = score.item()
308+
except ValueError:
309+
# non-scalar?
310+
pass
304311
if not isinstance(score, numbers.Number):
305312
raise ValueError("scoring must return a number, got %s (%s) instead."
306313
% (str(score), type(score)))

sklearn/model_selection/tests/test_validation.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import sys
55
import warnings
6+
import tempfile
7+
import os
68

79
import numpy as np
810
from scipy.sparse import coo_matrix, csr_matrix
@@ -769,3 +771,22 @@ def test_cross_val_predict_with_method():
769771
predictions = cross_val_predict(est, X, y, method=method,
770772
cv=kfold)
771773
assert_array_almost_equal(expected_predictions, predictions)
774+
775+
776+
def test_score_memmap():
777+
# Ensure a scalar score of memmap type is accepted
778+
iris = load_iris()
779+
X, y = iris.data, iris.target
780+
clf = MockClassifier()
781+
tf = tempfile.NamedTemporaryFile(mode='wb', delete=False)
782+
tf.write(b'Hello world!!!!!')
783+
tf.close()
784+
scores = np.memmap(tf.name, dtype=float)
785+
score = np.memmap(tf.name, shape=(), mode='w+', dtype=float)
786+
try:
787+
cross_val_score(clf, X, y, scoring=lambda est, X, y: score)
788+
# non-scalar should still fail
789+
assert_raises(ValueError, cross_val_score, clf, X, y,
790+
scoring=lambda est, X, y: scores)
791+
finally:
792+
os.unlink(tf.name)

0 commit comments

Comments
 (0)