|
11 | 11 | from packaging.version import Version |
12 | 12 | from sklearn import __version__ as sklearn_version |
13 | 13 | from sklearn import datasets |
14 | | -from sklearn.linear_model import LinearRegression, LogisticRegression |
| 14 | +from sklearn.linear_model import LogisticRegression |
15 | 15 | from sklearn.model_selection import GridSearchCV |
16 | 16 | from sklearn.pipeline import make_pipeline |
17 | 17 |
|
@@ -64,68 +64,67 @@ def test_ColumnSelector_in_gridsearch(): |
64 | 64 |
|
65 | 65 |
|
66 | 66 | def test_ColumnSelector_with_dataframe(): |
67 | | - boston = datasets.load_boston() |
68 | | - df_in = pd.DataFrame(boston.data, columns=boston.feature_names) |
69 | | - df_out = ColumnSelector(cols=("ZN", "CRIM")).transform(df_in) |
70 | | - assert df_out.shape == (506, 2) |
| 67 | + iris = datasets.load_iris() |
| 68 | + df_in = pd.DataFrame(iris.data, columns=iris.feature_names) |
| 69 | + df_out = ColumnSelector(cols=("sepal length (cm)", "sepal width (cm)")).transform( |
| 70 | + df_in |
| 71 | + ) |
| 72 | + assert df_out.shape == (150, 2) |
71 | 73 |
|
72 | 74 |
|
73 | 75 | def test_ColumnSelector_with_dataframe_and_int_columns(): |
74 | | - boston = datasets.load_boston() |
75 | | - df_in = pd.DataFrame(boston.data, columns=boston.feature_names) |
76 | | - df_out_str = ColumnSelector(cols=("INDUS", "CHAS")).transform(df_in) |
| 76 | + iris = datasets.load_iris() |
| 77 | + df_in = pd.DataFrame(iris.data, columns=iris.feature_names) |
| 78 | + df_out_str = ColumnSelector( |
| 79 | + cols=("petal length (cm)", "petal width (cm)") |
| 80 | + ).transform(df_in) |
77 | 81 | df_out_int = ColumnSelector(cols=(2, 3)).transform(df_in) |
78 | 82 |
|
79 | 83 | np.testing.assert_array_equal(df_out_str[:, 0], df_out_int[:, 0]) |
80 | 84 | np.testing.assert_array_equal(df_out_str[:, 1], df_out_int[:, 1]) |
81 | 85 |
|
82 | 86 |
|
83 | 87 | def test_ColumnSelector_with_dataframe_drop_axis(): |
84 | | - boston = datasets.load_boston() |
85 | | - df_in = pd.DataFrame(boston.data, columns=boston.feature_names) |
86 | | - X1_out = ColumnSelector(cols="ZN", drop_axis=True).transform(df_in) |
87 | | - assert X1_out.shape == (506,) |
| 88 | + iris = datasets.load_iris() |
| 89 | + df_in = pd.DataFrame(iris.data, columns=iris.feature_names) |
| 90 | + X1_out = ColumnSelector(cols="petal length (cm)", drop_axis=True).transform(df_in) |
| 91 | + assert X1_out.shape == (150,) |
88 | 92 |
|
89 | | - X1_out = ColumnSelector(cols=("ZN",), drop_axis=True).transform(df_in) |
90 | | - assert X1_out.shape == (506,) |
| 93 | + X1_out = ColumnSelector(cols=("petal length (cm)",), drop_axis=True).transform( |
| 94 | + df_in |
| 95 | + ) |
| 96 | + assert X1_out.shape == (150,) |
91 | 97 |
|
92 | | - X1_out = ColumnSelector(cols="ZN").transform(df_in) |
93 | | - assert X1_out.shape == (506, 1) |
| 98 | + X1_out = ColumnSelector(cols="petal length (cm)").transform(df_in) |
| 99 | + assert X1_out.shape == (150, 1) |
94 | 100 |
|
95 | | - X1_out = ColumnSelector(cols=("ZN",)).transform(df_in) |
96 | | - assert X1_out.shape == (506, 1) |
| 101 | + X1_out = ColumnSelector(cols=("petal length (cm)",)).transform(df_in) |
| 102 | + assert X1_out.shape == (150, 1) |
97 | 103 |
|
98 | 104 |
|
99 | 105 | def test_ColumnSelector_with_dataframe_in_gridsearch(): |
100 | | - boston = datasets.load_boston() |
101 | | - X = pd.DataFrame(boston.data, columns=boston.feature_names) |
102 | | - y = boston.target |
103 | | - pipe = make_pipeline(ColumnSelector(), LinearRegression()) |
| 106 | + iris = datasets.load_iris() |
| 107 | + X = pd.DataFrame(iris.data, columns=iris.feature_names) |
| 108 | + y = iris.target |
| 109 | + pipe = make_pipeline(ColumnSelector(), LogisticRegression()) |
104 | 110 | grid = { |
105 | | - "columnselector__cols": [["ZN", "RM"], ["ZN", "RM", "AGE"], "ZN", ["RM"]], |
106 | | - "linearregression__copy_X": [True, False], |
107 | | - "linearregression__fit_intercept": [True, False], |
| 111 | + "columnselector__cols": [ |
| 112 | + ["petal length (cm)", "petal width (cm)"], |
| 113 | + ["sepal length (cm)", "sepal width (cm)", "petal width (cm)"], |
| 114 | + ], |
108 | 115 | } |
109 | 116 |
|
110 | | - if Version(sklearn_version) < Version("0.24.1"): |
111 | | - gsearch1 = GridSearchCV( |
112 | | - estimator=pipe, |
113 | | - param_grid=grid, |
114 | | - cv=5, |
115 | | - n_jobs=1, |
116 | | - iid=False, |
117 | | - scoring="neg_mean_squared_error", |
118 | | - refit=False, |
119 | | - ) |
120 | | - else: |
121 | | - gsearch1 = GridSearchCV( |
122 | | - estimator=pipe, |
123 | | - param_grid=grid, |
124 | | - cv=5, |
125 | | - n_jobs=1, |
126 | | - scoring="neg_mean_squared_error", |
127 | | - refit=False, |
128 | | - ) |
| 117 | + gsearch1 = GridSearchCV( |
| 118 | + estimator=pipe, |
| 119 | + param_grid=grid, |
| 120 | + cv=5, |
| 121 | + n_jobs=1, |
| 122 | + scoring="accuracy", |
| 123 | + refit=False, |
| 124 | + ) |
129 | 125 |
|
130 | 126 | gsearch1.fit(X, y) |
131 | | - assert gsearch1.best_params_["columnselector__cols"] == ["ZN", "RM", "AGE"] |
| 127 | + assert gsearch1.best_params_["columnselector__cols"] == [
| 128 | + "petal length (cm)", |
| 129 | + "petal width (cm)", |
| 130 | + ] |
0 commit comments