|
21 | 21 | from numpy.testing import assert_array_equal |
22 | 22 | from numpy.testing import assert_raises |
23 | 23 |
|
| 24 | +from collections import defaultdict, Mapping |
| 25 | +from functools import partial |
24 | 26 | import pickle |
25 | 27 | from StringIO import StringIO |
26 | 28 |
|
| 29 | + |
27 | 30 | JUNK_FOOD_DOCS = ( |
28 | 31 | "the pizza pizza beer copyright", |
29 | 32 | "the pizza burger beer copyright", |
@@ -189,20 +192,20 @@ def test_char_ngram_analyzer(): |
189 | 192 |
|
190 | 193 |
|
191 | 194 | def test_countvectorizer_custom_vocabulary(): |
192 | | - what_we_like = ["pizza", "beer"] |
193 | | - vect = CountVectorizer(vocabulary=what_we_like) |
194 | | - vect.fit(JUNK_FOOD_DOCS) |
195 | | - assert_equal(set(vect.vocabulary_), set(what_we_like)) |
196 | | - X = vect.transform(JUNK_FOOD_DOCS) |
197 | | - assert_equal(X.shape[1], len(what_we_like)) |
198 | | - |
199 | | - # try again with a dict vocabulary |
200 | 195 | vocab = {"pizza": 0, "beer": 1} |
201 | | - vect = CountVectorizer(vocabulary=vocab) |
202 | | - vect.fit(JUNK_FOOD_DOCS) |
203 | | - assert_equal(vect.vocabulary_, vocab) |
204 | | - X = vect.transform(JUNK_FOOD_DOCS) |
205 | | - assert_equal(X.shape[1], len(what_we_like)) |
| 196 | + terms = set(vocab.keys()) |
| 197 | + |
| 198 | + # Try several supported vocabulary container types: mappings and plain iterables. |
| 199 | + for typ in [dict, list, iter, partial(defaultdict, int)]: |
| 200 | + v = typ(vocab) |
| 201 | + vect = CountVectorizer(vocabulary=v) |
| 202 | + vect.fit(JUNK_FOOD_DOCS) |
| 203 | + if isinstance(v, Mapping): |
| 204 | + assert_equal(vect.vocabulary_, vocab) |
| 205 | + else: |
| 206 | + assert_equal(set(vect.vocabulary_), terms) |
| 207 | + X = vect.transform(JUNK_FOOD_DOCS) |
| 208 | + assert_equal(X.shape[1], len(terms)) |
206 | 209 |
|
207 | 210 |
|
208 | 211 | def test_countvectorizer_custom_vocabulary_pipeline(): |
|
0 commit comments