Skip to content

Commit 4161d55

Browse files
committed
Improving Bunch Class to ensure consistent attributes
Adding set/getattr methods that fill/query the same thing as `bunch[key]`. Add test for a non-regression bug in fetch_20newsgroups.
1 parent 0fe613e commit 4161d55

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

sklearn/datasets/base.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,35 @@
2525

2626

2727
class Bunch(dict):
28-
"""Container object for datasets: dictionary-like object that
29-
exposes its keys as attributes."""
28+
"""Container object for datasets
29+
30+
Dictionary-like object that exposes its keys as attributes.
31+
32+
>>> b = Bunch(a=1, b=2)
33+
>>> b['b']
34+
2
35+
>>> b.b
36+
2
37+
>>> b.a = 3
38+
>>> b['a']
39+
3
40+
>>> b.c = 6
41+
>>> b['c']
42+
6
43+
44+
"""
3045

3146
def __init__(self, **kwargs):
3247
dict.__init__(self, kwargs)
33-
self.__dict__ = self
48+
49+
def __setattr__(self, key, value):
50+
self[key] = value
51+
52+
def __getattr__(self, key):
53+
return self[key]
54+
55+
def __getstate__(self):
56+
return self.__dict__
3457

3558

3659
def get_data_home(data_home=None):

sklearn/datasets/tests/test_20news.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ def test_20news():
3939
assert_equal(entry1, entry2)
4040

4141

42+
def test_20news_length_consistency():
43+
"""Checks the length consistencies within the bunch
44+
45+
This is a non-regression test for a bug present in 0.16.1.
46+
"""
47+
try:
48+
data = datasets.fetch_20newsgroups(
49+
subset='all', download_if_missing=False, shuffle=False)
50+
except IOError:
51+
raise SkipTest("Download 20 newsgroups to run this test")
52+
# Extract the full dataset
53+
data = datasets.fetch_20newsgroups(subset='all')
54+
assert_equal(len(data['data']), len(data.data))
55+
assert_equal(len(data['target']), len(data.target))
56+
assert_equal(len(data['filenames']), len(data.filenames))
57+
58+
4259
def test_20news_vectorized():
4360
# This test is slow.
4461
raise SkipTest("Test too slow.")

sklearn/datasets/twenty_newsgroups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def fetch_20newsgroups(data_home=None, subset='train', categories=None,
161161
for the test set, 'all' for both, with shuffled ordering.
162162
163163
data_home: optional, default: None
164-
Specify an download and cache folder for the datasets. If None,
164+
Specify a download and cache folder for the datasets. If None,
165165
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
166166
167167
categories: None or collection of string or unicode

0 commit comments

Comments
 (0)