Commit eef3bb9

Merge pull request rasbt#342 from jrbourbeau/add_decision_regions_kwargs
Adds style related dictionaries to plot_decision_regions
2 parents 66ba1b1 + 3990f22 commit eef3bb9

7 files changed, with 292 additions and 73 deletions

docs/sources/CHANGELOG.md

Lines changed: 5 additions & 3 deletions
@@ -35,6 +35,8 @@ The CHANGELOG for the current development version is available at
 - Apriori code is faster due to optimization in `onehot transformation` and the number of candidates generated by the `apriori` algorithm. ([#327](https://github.com/rasbt/mlxtend/pull/327) by [Jakub Smid](https://github.com/jaksmid))
 - The `OnehotTransactions` class (which is typically used in combination with the `apriori` function for association rule mining) is now more memory efficient as it uses boolean arrays instead of integer arrays. In addition, the `OnehotTransactions` class can now be provided with a `sparse` argument to generate sparse representations of the `onehot` matrix to further improve memory efficiency. ([#328](https://github.com/rasbt/mlxtend/pull/328) by [Jakub Smid](https://github.com/jaksmid))
 - The `OnehotTransactions` class has been deprecated and replaced by the `TransactionEncoder`. ([#332](https://github.com/rasbt/mlxtend/pull/332))
+- The `plot_decision_regions` function now has three new parameters, `scatter_kwargs`, `contourf_kwargs`, and `scatter_highlight_kwargs`, that can be used to modify the plotting style. ([#342](https://github.com/rasbt/mlxtend/pull/342) by [James Bourbeau](https://github.com/jrbourbeau))
+

 ##### Bug Fixes

@@ -55,7 +57,7 @@ The CHANGELOG for the current development version is available at

 - New `store_train_meta_features` parameter for `fit` in `StackingCVRegressor`. If True, train meta-features are stored in `self.train_meta_features_`.
 New `pred_meta_features` method for `StackingCVRegressor`. Users can get test meta-features using this method. ([#294](https://github.com/rasbt/mlxtend/pull/294) via [takashioya](https://github.com/takashioya))
-- The new `store_train_meta_features` attribute and `pred_meta_features` method for the `StackingCVRegressor` were also added to the `StackingRegressor`, `StackingClassifier`, and `StackingCVClassifier` ([#299](https://github.com/rasbt/mlxtend/pull/299) & [#300](https://github.com/rasbt/mlxtend/pull/300))
+- The new `store_train_meta_features` attribute and `pred_meta_features` method for the `StackingCVRegressor` were also added to the `StackingRegressor`, `StackingClassifier`, and `StackingCVClassifier` ([#299](https://github.com/rasbt/mlxtend/pull/299) & [#300](https://github.com/rasbt/mlxtend/pull/300))
 - New function (`evaluate.mcnemar_tables`) for creating multiple 2x2 contingency tables from model prediction arrays that can be used in multiple McNemar (post-hoc) tests or Cochran's Q or F tests, etc. ([#307](https://github.com/rasbt/mlxtend/issues/307))
 - New function (`evaluate.cochrans_q`) for performing Cochran's Q test to compare the accuracy of multiple classifiers. ([#310](https://github.com/rasbt/mlxtend/issues/310))

@@ -84,8 +86,8 @@ The CHANGELOG for the current development version is available at
 ##### Changes

 - All feature index tuples in `SequentialFeatureSelector` are now in sorted order. ([#262](https://github.com/rasbt/mlxtend/pull/262))
-- The `SequentialFeatureSelector` now runs the continuation of the floating inclusion/exclusion as described in Novovicova & Kittler (1994).
-Note that this didn't cause any difference in performance on any of the test scenarios but could lead to better performance in certain edge cases.
+- The `SequentialFeatureSelector` now runs the continuation of the floating inclusion/exclusion as described in Novovicova & Kittler (1994).
+Note that this didn't cause any difference in performance on any of the test scenarios but could lead to better performance in certain edge cases.
 ([#262](https://github.com/rasbt/mlxtend/pull/262))
 - `utils.Counter` now accepts a `name` variable to help distinguish between multiple counters; time precision can be set with the `precision` kwarg, and the new attribute `end_time` holds the time the last iteration completed. ([#278](https://github.com/rasbt/mlxtend/pull/278) via [Mathew Savage](https://github.com/matsavage))

docs/sources/user_guide/plotting/plot_decision_regions.ipynb

Lines changed: 108 additions & 33 deletions
Large diffs are not rendered by default.

mlxtend/plotting/decision_regions.py

Lines changed: 47 additions & 23 deletions
@@ -9,7 +9,7 @@
 from itertools import cycle
 import matplotlib.pyplot as plt
 import numpy as np
-from ..utils import check_Xy
+from ..utils import check_Xy, format_kwarg_dictionaries
 import warnings


@@ -49,7 +49,10 @@ def plot_decision_regions(X, y, clf,
                           legend=1,
                           hide_spines=True,
                           markers='s^oxv<>',
-                          colors='red,blue,limegreen,gray,cyan'):
+                          colors='red,blue,limegreen,gray,cyan',
+                          scatter_kwargs=None,
+                          contourf_kwargs=None,
+                          scatter_highlight_kwargs=None):
     """Plot decision regions of a classifier.

     Please note that this function assumes that class labels are
@@ -95,10 +98,16 @@ def plot_decision_regions(X, y, clf,
     legend : int (default: 1)
         Integer to specify the legend location.
         No legend if legend is 0.
-    markers : str (default 's^oxv<>')
+    markers : str (default: 's^oxv<>')
         Scatterplot markers.
-    colors : str (default 'red,blue,limegreen,gray,cyan')
+    colors : str (default: 'red,blue,limegreen,gray,cyan')
         Comma separated list of colors.
+    scatter_kwargs : dict (default: None)
+        Keyword arguments for underlying matplotlib scatter function.
+    contourf_kwargs : dict (default: None)
+        Keyword arguments for underlying matplotlib contourf function.
+    scatter_highlight_kwargs : dict (default: None)
+        Keyword arguments for underlying matplotlib scatter function.

     Returns
     ---------
@@ -204,15 +213,26 @@ def plot_decision_regions(X, y, clf,
     Z = clf.predict(X_predict.astype(X.dtype))
     Z = Z.reshape(xx.shape)
     # Plot decision region
+    # Make sure contourf_kwargs has backwards compatible defaults
+    contourf_kwargs_default = {'alpha': 0.3, 'antialiased': True}
+    contourf_kwargs = format_kwarg_dictionaries(
+        default_kwargs=contourf_kwargs_default,
+        user_kwargs=contourf_kwargs,
+        protected_keys=['colors', 'levels'])
     ax.contourf(xx, yy, Z,
-                alpha=0.3,
                 colors=colors,
                 levels=np.arange(Z.max() + 2) - 0.5,
-                antialiased=True)
+                **contourf_kwargs)

     ax.axis(xmin=xx.min(), xmax=xx.max(), y_min=yy.min(), y_max=yy.max())

     # Scatter training data samples
+    # Make sure scatter_kwargs has backwards compatible defaults
+    scatter_kwargs_default = {'alpha': 0.8, 'edgecolor': 'black'}
+    scatter_kwargs = format_kwarg_dictionaries(
+        default_kwargs=scatter_kwargs_default,
+        user_kwargs=scatter_kwargs,
+        protected_keys=['c', 'marker', 'label'])
     for idx, c in enumerate(np.unique(y)):
         if dim == 1:
             y_data = [0 for i in X[y == c]]
@@ -232,11 +252,10 @@ def plot_decision_regions(X, y, clf,

         ax.scatter(x=x_data,
                    y=y_data,
-                   alpha=0.8,
                    c=colors[idx],
                    marker=next(marker_gen),
-                   edgecolor='black',
-                   label=c)
+                   label=c,
+                   **scatter_kwargs)

     if hide_spines:
         ax.spines['right'].set_visible(False)
@@ -248,14 +267,6 @@ def plot_decision_regions(X, y, clf,
     if dim == 1:
         ax.axes.get_yaxis().set_ticks([])

-    if legend:
-        if dim > 2 and filler_feature_ranges is None:
-            pass
-        else:
-            handles, labels = ax.get_legend_handles_labels()
-            ax.legend(handles, labels,
-                      framealpha=0.3, scatterpoints=1, loc=legend)
-
     if plot_testdata:
         if dim == 1:
             x_data = X_highlight
@@ -270,13 +281,26 @@ def plot_decision_regions(X, y, clf,
             y_data = X_highlight[feature_range_mask, y_index]
             x_data = X_highlight[feature_range_mask, x_index]

+        # Make sure scatter_highlight_kwargs has backwards compatible defaults
+        scatter_highlight_defaults = {'c': '',
+                                      'edgecolor': 'black',
+                                      'alpha': 1.0,
+                                      'linewidths': 1,
+                                      'marker': 'o',
+                                      's': 80}
+        scatter_highlight_kwargs = format_kwarg_dictionaries(
+            default_kwargs=scatter_highlight_defaults,
+            user_kwargs=scatter_highlight_kwargs)
         ax.scatter(x_data,
                    y_data,
-                   c='',
-                   edgecolor='black',
-                   alpha=1.0,
-                   linewidths=1,
-                   marker='o',
-                   s=80)
+                   **scatter_highlight_kwargs)
+
+    if legend:
+        if dim > 2 and filler_feature_ranges is None:
+            pass
+        else:
+            handles, labels = ax.get_legend_handles_labels()
+            ax.legend(handles, labels,
+                      framealpha=0.3, scatterpoints=1, loc=legend)

     return ax
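
The hunks above pass `protected_keys` for the `contourf` call (`colors`, `levels`) and for the per-class `scatter` call (`c`, `marker`, `label`), since the function already sets those arguments itself. A minimal sketch of why stripping them matters follows. It uses `fake_contourf`, a hypothetical stand-in that only echoes its keyword arguments, so it does not depend on matplotlib or mlxtend; the user dictionary is likewise made up for illustration.

```python
def fake_contourf(xx, yy, Z, colors=None, levels=None, **kwargs):
    """Hypothetical stand-in for ax.contourf; it only echoes its keywords."""
    return dict(colors=colors, levels=levels, **kwargs)


defaults = {'alpha': 0.3, 'antialiased': True}
user = {'alpha': 0.5, 'colors': 'black'}  # 'colors' collides with a fixed argument

# Naive merge: user values win, but the colliding key is still present.
naive = {**defaults, **user}
try:
    fake_contourf(None, None, None, colors='red,blue', levels=[0, 1], **naive)
    collided = False
except TypeError:
    # Python rejects a keyword passed both explicitly and through **kwargs.
    collided = True

# Dropping the protected keys first, as the patch does, avoids the collision.
safe = {k: v for k, v in naive.items() if k not in ('colors', 'levels')}
result = fake_contourf(None, None, None, colors='red,blue', levels=[0, 1], **safe)
print(collided, result)
```

Note that the highlight `scatter` call in the patch passes no protected keys, so every default there (including `c`, `marker`, and `s`) can be overridden by the caller.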

mlxtend/plotting/tests/test_decision_regions.py

Lines changed: 34 additions & 0 deletions
@@ -79,3 +79,37 @@ def test_y_ary_dim():
                   'y must be a 1D array',
                   plot_decision_regions,
                   X[:, :2], y[:, np.newaxis], sr)
+
+
+def test_scatter_kwargs_type():
+    kwargs = 'not a dictionary'
+    sr.fit(X[:, :2], y)
+    message = ('d must be of type dict or None, but got '
+               '{} instead'.format(type(kwargs)))
+    assert_raises(TypeError,
+                  message,
+                  plot_decision_regions,
+                  X[:, :2], y, sr, scatter_kwargs=kwargs)
+
+
+def test_contourf_kwargs_type():
+    kwargs = 'not a dictionary'
+    sr.fit(X[:, :2], y)
+    message = ('d must be of type dict or None, but got '
+               '{} instead'.format(type(kwargs)))
+    assert_raises(TypeError,
+                  message,
+                  plot_decision_regions,
+                  X[:, :2], y, sr, contourf_kwargs=kwargs)
+
+
+def test_scatter_highlight_kwargs_type():
+    kwargs = 'not a dictionary'
+    sr.fit(X[:, :2], y)
+    message = ('d must be of type dict or None, but got '
+               '{} instead'.format(type(kwargs)))
+    assert_raises(TypeError,
+                  message,
+                  plot_decision_regions,
+                  X[:, :2], y, sr, X_highlight=X[:, :2],
+                  scatter_highlight_kwargs=kwargs)

mlxtend/utils/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -6,6 +6,7 @@

 from .counter import Counter
 from .testing import assert_raises
-from .checking import check_Xy
+from .checking import check_Xy, format_kwarg_dictionaries

-__all__ = ["Counter", "assert_raises", "check_Xy"]
+__all__ = ["Counter", "assert_raises", "check_Xy",
+           "format_kwarg_dictionaries"]

mlxtend/utils/checking.py

Lines changed: 33 additions & 0 deletions
@@ -36,3 +36,36 @@ def check_Xy(X, y, y_int=True):
     if y.shape[0] != X.shape[0]:
         raise ValueError('y and X must contain the same number of samples. '
                          'Got y: %d, X: %d' % (y.shape[0], X.shape[0]))
+
+
+def format_kwarg_dictionaries(default_kwargs=None, user_kwargs=None,
+                              protected_keys=None):
+    """Function to combine default and user specified kwargs dictionaries
+
+    Parameters
+    ----------
+    default_kwargs : dict, optional
+        Default kwargs (default is None).
+    user_kwargs : dict, optional
+        User specified kwargs (default is None).
+    protected_keys : array_like, optional
+        Sequence of keys to be removed from the returned dictionary
+        (default is None).
+
+    Returns
+    -------
+    formatted_kwargs : dict
+        Formatted kwargs dictionary.
+    """
+    formatted_kwargs = {}
+    for d in [default_kwargs, user_kwargs]:
+        if not isinstance(d, (dict, type(None))):
+            raise TypeError('d must be of type dict or None, but '
+                            'got {} instead'.format(type(d)))
+        if d is not None:
+            formatted_kwargs.update(d)
+    if protected_keys is not None:
+        for key in protected_keys:
+            formatted_kwargs.pop(key, None)
+
+    return formatted_kwargs
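
To illustrate the merge order of the helper added above, here is a small self-contained sketch. The function body is copied from the diff (with a shortened docstring) so the block runs without mlxtend installed; the `user` dictionary is a hypothetical input chosen for illustration. It shows that user values override defaults, that protected keys are dropped from the result, and that a non-dict argument raises `TypeError`.

```python
def format_kwarg_dictionaries(default_kwargs=None, user_kwargs=None,
                              protected_keys=None):
    """Combine default and user kwargs (body copied from the diff above)."""
    formatted_kwargs = {}
    for d in [default_kwargs, user_kwargs]:
        if not isinstance(d, (dict, type(None))):
            raise TypeError('d must be of type dict or None, but '
                            'got {} instead'.format(type(d)))
        if d is not None:
            formatted_kwargs.update(d)
    if protected_keys is not None:
        for key in protected_keys:
            formatted_kwargs.pop(key, None)
    return formatted_kwargs


# Defaults and protected keys as used for the per-class scatter call.
scatter_defaults = {'alpha': 0.8, 'edgecolor': 'black'}
# Hypothetical user input: overrides alpha, adds s, and tries to set marker.
user = {'alpha': 0.4, 's': 30, 'marker': 'x'}

merged = format_kwarg_dictionaries(default_kwargs=scatter_defaults,
                                   user_kwargs=user,
                                   protected_keys=['c', 'marker', 'label'])
print(merged)

# A non-dict, non-None argument is rejected with a TypeError.
try:
    format_kwarg_dictionaries(user_kwargs='not a dictionary')
    raised = False
except TypeError:
    raised = True
print(raised)
```

Because the helper builds a fresh dictionary, neither `scatter_defaults` nor `user` is modified by the call.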

mlxtend/utils/tests/test_checking_inputs.py

Lines changed: 62 additions & 12 deletions
@@ -5,20 +5,24 @@
 # License: BSD 3 clause

 from mlxtend.utils import assert_raises
-from mlxtend.utils import check_Xy
+from mlxtend.utils import check_Xy, format_kwarg_dictionaries
 import numpy as np
 import sys
 import os

 y = np.array([1, 2, 3, 4])
 X = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])

+d_default = {'key1': 1, 'key2': 2}
+d_user = {'key3': 3, 'key4': 4}
+protected_keys = ['key1', 'key4']

-def test_ok():
+
+def test_check_Xy_ok():
     check_Xy(X, y)


-def test_invalid_type_X():
+def test_check_Xy_invalid_type_X():
     expect = "X must be a NumPy array. Found <class 'list'>"
     if (sys.version_info < (3, 0)):
         expect = expect.replace('class', 'type')
@@ -29,15 +33,15 @@ def test_invalid_type_X():
                   y)


-def test_float16_X():
+def test_check_Xy_float16_X():
     check_Xy(X.astype(np.float16), y)


-def test_float16_y():
+def test_check_Xy_float16_y():
     check_Xy(X, y.astype(np.int16))


-def test_invalid_type_y():
+def test_check_Xy_invalid_type_y():
     expect = "y must be a NumPy array. Found <class 'list'>"
     if (sys.version_info < (3, 0)):
         expect = expect.replace('class', 'type')
@@ -48,15 +52,15 @@ def test_invalid_type_y():
                   [1, 2, 3, 4])


-def test_invalid_dtype_X():
+def test_check_Xy_invalid_dtype_X():
     assert_raises(ValueError,
                   'X must be an integer or float array. Found object.',
                   check_Xy,
                   X.astype('object'),
                   y)


-def test_invalid_dtype_y():
+def test_check_Xy_invalid_dtype_y():

     if (sys.version_info > (3, 0)):
         expect = ('y must be an integer array. Found <U1. '
@@ -71,7 +75,7 @@ def test_invalid_dtype_y():
                   np.array(['a', 'b', 'c', 'd']))


-def test_invalid_dim_y():
+def test_check_Xy_invalid_dim_y():
     if sys.version_info[:2] == (2, 7) and os.name == 'nt':
         s = 'y must be a 1D array. Found (4L, 2L)'
     else:
@@ -83,7 +87,7 @@ def test_invalid_dim_y():
                   X.astype(np.integer))


-def test_invalid_dim_X():
+def test_check_Xy_invalid_dim_X():
     if sys.version_info[:2] == (2, 7) and os.name == 'nt':
         s = 'X must be a 2D array. Found (4L,)'
     else:
@@ -95,7 +99,7 @@ def test_invalid_dim_X():
                   y)


-def test_unequal_length_X():
+def test_check_Xy_unequal_length_X():
     assert_raises(ValueError,
                   ('y and X must contain the same number of samples. '
                    'Got y: 4, X: 3'),
@@ -104,10 +108,56 @@ def test_unequal_length_X():
                   y)


-def test_unequal_length_y():
+def test_check_Xy_unequal_length_y():
     assert_raises(ValueError,
                   ('y and X must contain the same number of samples. '
                    'Got y: 3, X: 4'),
                   check_Xy,
                   X,
                   y[1:])
+
+
+def test_format_kwarg_dictionaries_defaults_empty():
+    empty = format_kwarg_dictionaries()
+    assert isinstance(empty, dict)
+    assert len(empty) == 0
+
+
+def test_format_kwarg_dictionaries_protected_keys():
+    formatted_kwargs = format_kwarg_dictionaries(
+        default_kwargs=d_default,
+        user_kwargs=d_user,
+        protected_keys=protected_keys)
+
+    for key in protected_keys:
+        assert key not in formatted_kwargs
+
+
+def test_format_kwarg_dictionaries_no_default_kwargs():
+    formatted_kwargs = format_kwarg_dictionaries(user_kwargs=d_user)
+    assert formatted_kwargs == d_user
+
+
+def test_format_kwarg_dictionaries_no_user_kwargs():
+    formatted_kwargs = format_kwarg_dictionaries(default_kwargs=d_default)
+    assert formatted_kwargs == d_default
+
+
+def test_format_kwarg_dictionaries_default_kwargs_invalid_type():
+    invalid_kwargs = 'not a dictionary'
+    message = ('d must be of type dict or None, but got '
+               '{} instead'.format(type(invalid_kwargs)))
+    assert_raises(TypeError,
+                  message,
+                  format_kwarg_dictionaries,
+                  default_kwargs=invalid_kwargs)
+
+
+def test_format_kwarg_dictionaries_user_kwargs_invalid_type():
+    invalid_kwargs = 'not a dictionary'
+    message = ('d must be of type dict or None, but got '
+               '{} instead'.format(type(invalid_kwargs)))
+    assert_raises(TypeError,
+                  message,
+                  format_kwarg_dictionaries,
+                  user_kwargs=invalid_kwargs)
