Skip to content

Commit f75ecd7

Browse files
amuellerogrisel
authored andcommitted
make pickle version test more stable (scikit-learn#7415)
1 parent 4d9fab5 commit f75ecd7

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

sklearn/tests/test_base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ def __getstate__(self):
324324
return self.__dict__
325325

326326

327+
class TreeBadVersion(DecisionTreeClassifier):
328+
def __getstate__(self):
329+
return dict(self.__dict__.items(), _sklearn_version="something")
330+
331+
327332
def test_pickle_version_warning():
328333
# check that warnings are raised when unpickling in a different version
329334

@@ -335,9 +340,9 @@ def test_pickle_version_warning():
335340
assert_no_warnings(pickle.loads, tree_pickle)
336341

337342
# check that warning is raised on different version
338-
tree_pickle_other = tree_pickle.replace(sklearn.__version__.encode(),
339-
b"something")
340-
message = ("Trying to unpickle estimator DecisionTreeClassifier from "
343+
tree = TreeBadVersion().fit(iris.data, iris.target)
344+
tree_pickle_other = pickle.dumps(tree)
345+
message = ("Trying to unpickle estimator TreeBadVersion from "
341346
"version {0} when using version {1}. This might lead to "
342347
"breaking code or invalid results. "
343348
"Use at your own risk.".format("something",
@@ -351,7 +356,7 @@ def test_pickle_version_warning():
351356
tree_pickle_noversion = pickle.dumps(tree)
352357
assert_false(b"version" in tree_pickle_noversion)
353358
message = message.replace("something", "pre-0.18")
354-
message = message.replace("DecisionTreeClassifier", "TreeNoVersion")
359+
message = message.replace("TreeBadVersion", "TreeNoVersion")
355360
# check we got the warning about using pre-0.18 pickle
356361
assert_warns_message(UserWarning, message, pickle.loads,
357362
tree_pickle_noversion)

0 commit comments

Comments
 (0)