Skip to content

Commit f4882b5

Browse files
committed
DOC + TST vocabulary arg in CountVect docstring
Somewhere during the last refactoring, the documentation for the `vocabulary` argument went missing. Also, check for a `Mapping` subclass instead of `hasattr(vocabulary, "get")`, and test with a few different vocabulary types.
1 parent 1601b27 commit f4882b5

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

sklearn/feature_extraction/tests/test_text.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from numpy.testing import assert_array_equal
2222
from numpy.testing import assert_raises
2323

24+
from collections import defaultdict, Mapping
25+
from functools import partial
2426
import pickle
2527
from StringIO import StringIO
2628

29+
2730
JUNK_FOOD_DOCS = (
2831
"the pizza pizza beer copyright",
2932
"the pizza burger beer copyright",
@@ -189,20 +192,20 @@ def test_char_ngram_analyzer():
189192

190193

191194
def test_countvectorizer_custom_vocabulary():
192-
what_we_like = ["pizza", "beer"]
193-
vect = CountVectorizer(vocabulary=what_we_like)
194-
vect.fit(JUNK_FOOD_DOCS)
195-
assert_equal(set(vect.vocabulary_), set(what_we_like))
196-
X = vect.transform(JUNK_FOOD_DOCS)
197-
assert_equal(X.shape[1], len(what_we_like))
198-
199-
# try again with a dict vocabulary
200195
vocab = {"pizza": 0, "beer": 1}
201-
vect = CountVectorizer(vocabulary=vocab)
202-
vect.fit(JUNK_FOOD_DOCS)
203-
assert_equal(vect.vocabulary_, vocab)
204-
X = vect.transform(JUNK_FOOD_DOCS)
205-
assert_equal(X.shape[1], len(what_we_like))
196+
terms = set(vocab.keys())
197+
198+
# Try a few of the supported types.
199+
for typ in [dict, list, iter, partial(defaultdict, int)]:
200+
v = typ(vocab)
201+
vect = CountVectorizer(vocabulary=v)
202+
vect.fit(JUNK_FOOD_DOCS)
203+
if isinstance(v, Mapping):
204+
assert_equal(vect.vocabulary_, vocab)
205+
else:
206+
assert_equal(set(vect.vocabulary_), terms)
207+
X = vect.transform(JUNK_FOOD_DOCS)
208+
assert_equal(X.shape[1], len(terms))
206209

207210

208211
def test_countvectorizer_custom_vocabulary_pipeline():

sklearn/feature_extraction/text.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
build feature vectors from text documents.
1111
"""
1212

13+
from collections import Mapping
14+
from operator import itemgetter
1315
import re
1416
import unicodedata
15-
from operator import itemgetter
1617
import warnings
1718

1819
import numpy as np
@@ -167,6 +168,11 @@ class CountVectorizer(BaseEstimator):
167168
168169
This parameter is ignored if vocabulary is not None.
169170
171+
vocabulary: Mapping or iterable, optional
172+
Either a Mapping (e.g., a dict) where keys are terms and values are
173+
indices in the feature matrix, or an iterable over terms. If not
174+
given, a vocabulary is determined from the input documents.
175+
170176
binary: boolean, False by default.
171177
If True, all non zero counts are set to 1. This is useful for discrete
172178
probabilistic models that model binary events rather than integer
@@ -201,7 +207,7 @@ def __init__(self, input='content', charset='utf-8',
201207
self.max_features = max_features
202208
if vocabulary is not None:
203209
self.fixed_vocabulary = True
204-
if not hasattr(vocabulary, 'get'):
210+
if not isinstance(vocabulary, Mapping):
205211
vocabulary = dict((t, i) for i, t in enumerate(vocabulary))
206212
self.vocabulary_ = vocabulary
207213
else:

0 commit comments

Comments
 (0)