Skip to content

Commit 7effc37

Browse files
authored
MNT Warn when shuffle is False but random_state is not None (scikit-learn#15353)
1 parent 7281083 commit 7effc37

File tree

5 files changed

+37
-10
lines changed

5 files changed

+37
-10
lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,11 @@ Changelog
484484
where one test set could be `n_classes` larger than another. Test sets should
485485
now be near-equally sized. :pr:`14704` by `Joel Nothman`_.
486486

487+
- |API| :class:`model_selection.KFold` and
488+
:class:`model_selection.StratifiedKFold` now raise a warning if
489+
`random_state` is not None but `shuffle` is False. This will raise an error in
490+
0.24.
491+
487492
:mod:`sklearn.multioutput`
488493
..........................
489494

sklearn/linear_model/tests/test_logistic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,7 @@ def test_LogisticRegressionCV_GridSearchCV_elastic_net(multi_class):
14961496
X, y = make_classification(n_samples=100, n_classes=3, n_informative=3,
14971497
random_state=0)
14981498

1499-
cv = StratifiedKFold(5, random_state=0)
1499+
cv = StratifiedKFold(5)
15001500

15011501
l1_ratios = np.linspace(0, 1, 3)
15021502
Cs = np.logspace(-4, 4, 3)
@@ -1527,7 +1527,7 @@ def test_LogisticRegressionCV_GridSearchCV_elastic_net_ovr():
15271527
X, y = make_classification(n_samples=100, n_classes=3, n_informative=3,
15281528
random_state=0)
15291529
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
1530-
cv = StratifiedKFold(5, random_state=0)
1530+
cv = StratifiedKFold(5)
15311531

15321532
l1_ratios = np.linspace(0, 1, 3)
15331533
Cs = np.logspace(-4, 4, 3)
@@ -1770,7 +1770,7 @@ def test_scores_attribute_layout_elasticnet():
17701770
# the third dimension corresponds to l1_ratios.
17711771

17721772
X, y = make_classification(n_samples=1000, random_state=0)
1773-
cv = StratifiedKFold(n_splits=5, shuffle=False)
1773+
cv = StratifiedKFold(n_splits=5)
17741774

17751775
l1_ratios = [.1, .9]
17761776
Cs = [.1, 1, 10]

sklearn/model_selection/_split.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,15 @@ def __init__(self, n_splits, shuffle, random_state):
287287
raise TypeError("shuffle must be True or False;"
288288
" got {0}".format(shuffle))
289289

290+
if not shuffle and random_state is not None: # None is the default
291+
# TODO 0.24: raise a ValueError instead of a warning
292+
warnings.warn(
293+
'Setting a random_state has no effect since shuffle is '
294+
'False. This will raise an error in 0.24. You should leave '
295+
'random_state to its default (None), or set shuffle=True.',
296+
DeprecationWarning
297+
)
298+
290299
self.n_splits = n_splits
291300
self.shuffle = shuffle
292301
self.random_state = random_state
@@ -374,7 +383,8 @@ class KFold(_BaseKFold):
374383
If int, random_state is the seed used by the random number generator;
375384
If RandomState instance, random_state is the random number generator;
376385
If None, the random number generator is the RandomState instance used
377-
by `np.random`. Used when ``shuffle`` == True.
386+
by `np.random`. Only used when ``shuffle`` is True. This should be left
387+
as None if ``shuffle`` is False.
378388
379389
Examples
380390
--------
@@ -579,7 +589,8 @@ class StratifiedKFold(_BaseKFold):
579589
If int, random_state is the seed used by the random number generator;
580590
If RandomState instance, random_state is the random number generator;
581591
If None, the random number generator is the RandomState instance used
582-
by `np.random`. Used when ``shuffle`` == True.
592+
by `np.random`. Only used when ``shuffle`` is True. This should be left
593+
as None if ``shuffle`` is False.
583594
584595
Examples
585596
--------

sklearn/model_selection/tests/test_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ def test_search_cv_results_none_param():
12201220
X, y = [[1], [2], [3], [4], [5]], [0, 0, 0, 0, 1]
12211221
estimators = (DecisionTreeRegressor(), DecisionTreeClassifier())
12221222
est_parameters = {"random_state": [0, None]}
1223-
cv = KFold(random_state=0)
1223+
cv = KFold()
12241224

12251225
for est in estimators:
12261226
grid_search = GridSearchCV(est, est_parameters, cv=cv,
@@ -1294,7 +1294,7 @@ def test_grid_search_correct_score_results():
12941294

12951295
def test_fit_grid_point():
12961296
X, y = make_classification(random_state=0)
1297-
cv = StratifiedKFold(random_state=0)
1297+
cv = StratifiedKFold()
12981298
svc = LinearSVC(random_state=0)
12991299
scorer = make_scorer(accuracy_score)
13001300

@@ -1345,7 +1345,7 @@ def test_grid_search_with_multioutput_data():
13451345
random_state=0)
13461346

13471347
est_parameters = {"max_depth": [1, 2, 3, 4]}
1348-
cv = KFold(random_state=0)
1348+
cv = KFold()
13491349

13501350
estimators = [DecisionTreeRegressor(random_state=0),
13511351
DecisionTreeClassifier(random_state=0)]

sklearn/model_selection/tests/test_split.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ def test_stratified_kfold_ratios(k, shuffle):
390390
distr = np.bincount(y) / len(y)
391391

392392
test_sizes = []
393-
skf = StratifiedKFold(k, random_state=0, shuffle=shuffle)
393+
random_state = None if not shuffle else 0
394+
skf = StratifiedKFold(k, random_state=random_state, shuffle=shuffle)
394395
for train, test in skf.split(X, y):
395396
assert_allclose(np.bincount(y[train]) / len(train), distr, atol=0.02)
396397
assert_allclose(np.bincount(y[test]) / len(test), distr, atol=0.02)
@@ -409,9 +410,10 @@ def test_stratified_kfold_label_invariance(k, shuffle):
409410
X = np.ones(len(y))
410411

411412
def get_splits(y):
413+
random_state = None if not shuffle else 0
412414
return [(list(train), list(test))
413415
for train, test
414-
in StratifiedKFold(k, random_state=0,
416+
in StratifiedKFold(k, random_state=random_state,
415417
shuffle=shuffle).split(X, y)]
416418

417419
splits_base = get_splits(y)
@@ -1582,3 +1584,12 @@ def test_leave_p_out_empty_trainset():
15821584
ValueError,
15831585
match='p=2 must be strictly less than the number of samples=2'):
15841586
next(cv.split(X, y, groups=[1, 2]))
1587+
1588+
1589+
@pytest.mark.parametrize('Klass', (KFold, StratifiedKFold))
1590+
def test_random_state_shuffle_false(Klass):
1591+
# passing a non-default random_state when shuffle=False makes no sense
1592+
# TODO 0.24: raise a ValueError instead of a warning
1593+
with pytest.warns(DeprecationWarning,
1594+
match='has no effect since shuffle is False'):
1595+
Klass(3, shuffle=False, random_state=0)

0 commit comments

Comments
 (0)