@@ -236,6 +236,32 @@ def build_analyzer(self):
236236 raise ValueError ('%s is not a valid tokenization scheme/analyzer' %
237237 self .analyzer )
238238
239+ def _check_vocabulary (self ):
240+ vocabulary = self .vocabulary
241+ if vocabulary is not None :
242+ if not isinstance (vocabulary , Mapping ):
243+ vocab = {}
244+ for i , t in enumerate (vocabulary ):
245+ if vocab .setdefault (t , i ) != i :
246+ msg = "Duplicate term in vocabulary: %r" % t
247+ raise ValueError (msg )
248+ vocabulary = vocab
249+ else :
250+ indices = set (six .itervalues (vocabulary ))
251+ if len (indices ) != len (vocabulary ):
252+ raise ValueError ("Vocabulary contains repeated indices." )
253+ for i in xrange (len (vocabulary )):
254+ if i not in indices :
255+ msg = ("Vocabulary of size %d doesn't contain index "
256+ "%d." % (len (vocabulary ), i ))
257+ raise ValueError (msg )
258+ if not vocabulary :
259+ raise ValueError ("empty vocabulary passed to fit" )
260+ self .fixed_vocabulary = True
261+ self .vocabulary_ = dict (vocabulary )
262+ else :
263+ self .fixed_vocabulary = False
264+
239265
240266class HashingVectorizer (BaseEstimator , VectorizerMixin ):
241267 """Convert a collection of text documents to a matrix of token occurrences
@@ -616,29 +642,7 @@ def __init__(self, input='content', encoding='utf-8',
616642 "max_features=%r, neither a positive integer nor None"
617643 % max_features )
618644 self .ngram_range = ngram_range
619- if vocabulary is not None :
620- if not isinstance (vocabulary , Mapping ):
621- vocab = {}
622- for i , t in enumerate (vocabulary ):
623- if vocab .setdefault (t , i ) != i :
624- msg = "Duplicate term in vocabulary: %r" % t
625- raise ValueError (msg )
626- vocabulary = vocab
627- else :
628- indices = set (six .itervalues (vocabulary ))
629- if len (indices ) != len (vocabulary ):
630- raise ValueError ("Vocabulary contains repeated indices." )
631- for i in xrange (len (vocabulary )):
632- if i not in indices :
633- msg = ("Vocabulary of size %d doesn't contain index "
634- "%d." % (len (vocabulary ), i ))
635- raise ValueError (msg )
636- if not vocabulary :
637- raise ValueError ("empty vocabulary passed to fit" )
638- self .fixed_vocabulary = True
639- self .vocabulary_ = dict (vocabulary )
640- else :
641- self .fixed_vocabulary = False
645+ self .vocabulary = vocabulary
642646 self .binary = binary
643647 self .dtype = dtype
644648
@@ -773,6 +777,7 @@ def fit_transform(self, raw_documents, y=None):
773777 # We intentionally don't call the transform method to make
774778 # fit_transform overridable without unwanted side effects in
775779 # TfidfVectorizer.
780+ self ._check_vocabulary ()
776781 max_df = self .max_df
777782 min_df = self .min_df
778783 max_features = self .max_features
@@ -820,6 +825,9 @@ def transform(self, raw_documents):
820825 X : sparse matrix, [n_samples, n_features]
821826 Document-term matrix.
822827 """
828+ if not hasattr (self , 'vocabulary_' ):
829+ self ._check_vocabulary ()
830+
823831 if not hasattr (self , 'vocabulary_' ) or len (self .vocabulary_ ) == 0 :
824832 raise ValueError ("Vocabulary wasn't fitted or is empty!" )
825833
0 commit comments