Skip to content

Commit b4984e2

Browse files
SolutusImmensusglemaitre
authored andcommitted
[MRG+1] FIX Add some validation in the constructor of ParameterGrid (scikit-learn#11090)
1 parent ef8d22a commit b4984e2

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

doc/whats_new/v0.20.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,9 @@ Misc
615615
acts as an upper bound on iterations.
616616
:issue:`#10982` by :user:`Juliet Lawton <julietcl>`
617617

618+
- Invalid input for :class:`model_selection.ParameterGrid` now raises TypeError.
619+
:issue:`10928` by :user:`Solutus Immensus <solutusimmensus>`
620+
618621
Changes to estimator checks
619622
---------------------------
620623

sklearn/model_selection/_search.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# License: BSD 3 clause
1414

1515
from abc import ABCMeta, abstractmethod
16-
from collections import Mapping, namedtuple, defaultdict, Sequence
16+
from collections import Mapping, namedtuple, defaultdict, Sequence, Iterable
1717
from functools import partial, reduce
1818
from itertools import product
1919
import operator
@@ -90,10 +90,26 @@ class ParameterGrid(object):
9090
"""
9191

9292
def __init__(self, param_grid):
93+
if not isinstance(param_grid, (Mapping, Iterable)):
94+
raise TypeError('Parameter grid is not a dict or '
95+
'a list ({!r})'.format(param_grid))
96+
9397
if isinstance(param_grid, Mapping):
9498
# wrap dictionary in a singleton list to support either dict
9599
# or list of dicts
96100
param_grid = [param_grid]
101+
102+
# check if all entries are dictionaries of lists
103+
for grid in param_grid:
104+
if not isinstance(grid, dict):
105+
raise TypeError('Parameter grid is not a '
106+
'dict ({!r})'.format(grid))
107+
for key in grid:
108+
if not isinstance(grid[key], Iterable):
109+
raise TypeError('Parameter grid value is not iterable '
110+
'(key={!r}, value={!r})'
111+
.format(key, grid[key]))
112+
97113
self.param_grid = param_grid
98114

99115
def __iter__(self):

sklearn/model_selection/tests/test_search.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import numpy as np
1414
import scipy.sparse as sp
15+
import pytest
1516

1617
from sklearn.utils.fixes import sp_version
1718
from sklearn.utils.testing import assert_equal
@@ -126,7 +127,19 @@ def assert_grid_iter_equals_getitem(grid):
126127
assert_equal(list(grid), [grid[i] for i in range(len(grid))])
127128

128129

130+
@pytest.mark.parametrize(
131+
"input, error_type, error_message",
132+
[(0, TypeError, 'Parameter grid is not a dict or a list (0)'),
133+
([{'foo': [0]}, 0], TypeError, 'Parameter grid is not a dict (0)'),
134+
({'foo': 0}, TypeError, "Parameter grid value is not iterable "
135+
"(key='foo', value=0)")]
136+
)
137+
def test_validate_parameter_grid_input(input, error_type, error_message):
138+
with pytest.raises(error_type, message=error_message):
139+
ParameterGrid(input)
140+
129141
def test_parameter_grid():
142+
130143
# Test basic properties of ParameterGrid.
131144
params1 = {"foo": [1, 2, 3]}
132145
grid1 = ParameterGrid(params1)

0 commit comments

Comments
 (0)