Skip to content

Commit 70fd42e

Browse files
adrinjalalijnothman
authored andcommitted
FIX make sure vectorizers read data from file before analyzing (scikit-learn#13641)
1 parent 38bff6d commit 70fd42e

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

doc/whats_new/v0.21.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ random sampling procedures.
2727
- :class:`linear_model.LogisticRegression` and
2828
:class:`linear_model.LogisticRegressionCV` with 'saga' solver. |Fix|
2929
- :class:`ensemble.GradientBoostingClassifier` |Fix|
30+
- :class:`sklearn.feature_extraction.text.HashingVectorizer`,
31+
:class:`sklearn.feature_extraction.text.TfidfVectorizer`, and
32+
:class:`sklearn.feature_extraction.text.CountVectorizer` |API|
3033
- :class:`neural_network.MLPClassifier` |Fix|
3134
- :func:`svm.SVC.decision_function` and
3235
:func:`multiclass.OneVsOneClassifier.decision_function`. |Fix|
@@ -265,6 +268,17 @@ Support for Python 3.4 and below has been officially dropped.
265268
- |API| Deprecated :mod:`externals.six` since we have dropped support for
266269
Python 2.7. :issue:`12916` by :user:`Hanmin Qin <qinhanmin2014>`.
267270

271+
:mod:`sklearn.feature_extraction`
272+
.................................
273+
274+
- |API| If ``input='file'`` or ``input='filename'``, and a callable is given
275+
as the ``analyzer``, :class:`sklearn.feature_extraction.text.HashingVectorizer`,
276+
:class:`sklearn.feature_extraction.text.TfidfVectorizer`, and
277+
:class:`sklearn.feature_extraction.text.CountVectorizer` now read the data
278+
from the file(s) and then pass it to the given ``analyzer``, instead of
279+
passing the file name(s) or the file object(s) to the analyzer.
280+
:issue:`13641` by `Adrin Jalali`_.
281+
268282
:mod:`sklearn.impute`
269283
.....................
270284

sklearn/feature_extraction/tests/test_text.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from numpy.testing import assert_array_almost_equal
3030
from numpy.testing import assert_array_equal
3131
from sklearn.utils import IS_PYPY
32+
from sklearn.exceptions import ChangedBehaviorWarning
3233
from sklearn.utils.testing import (assert_equal, assert_not_equal,
3334
assert_almost_equal, assert_in,
3435
assert_less, assert_greater,
@@ -1196,3 +1197,47 @@ def build_preprocessor(self):
11961197
.findall(doc),
11971198
stop_words=['and'])
11981199
assert _check_stop_words_consistency(vec) is True
1200+
1201+
1202+
@pytest.mark.parametrize('Estimator',
1203+
[CountVectorizer, TfidfVectorizer, HashingVectorizer])
1204+
@pytest.mark.parametrize(
1205+
'input_type, err_type, err_msg',
1206+
[('filename', FileNotFoundError, ''),
1207+
('file', AttributeError, "'str' object has no attribute 'read'")]
1208+
)
1209+
def test_callable_analyzer_error(Estimator, input_type, err_type, err_msg):
1210+
data = ['this is text, not file or filename']
1211+
with pytest.raises(err_type, match=err_msg):
1212+
Estimator(analyzer=lambda x: x.split(),
1213+
input=input_type).fit_transform(data)
1214+
1215+
1216+
@pytest.mark.parametrize('Estimator',
1217+
[CountVectorizer, TfidfVectorizer, HashingVectorizer])
1218+
@pytest.mark.parametrize(
1219+
'analyzer', [lambda doc: open(doc, 'r'), lambda doc: doc.read()]
1220+
)
1221+
@pytest.mark.parametrize('input_type', ['file', 'filename'])
1222+
def test_callable_analyzer_change_behavior(Estimator, analyzer, input_type):
1223+
data = ['this is text, not file or filename']
1224+
warn_msg = 'Since v0.21, vectorizer'
1225+
with pytest.raises((FileNotFoundError, AttributeError)):
1226+
with pytest.warns(ChangedBehaviorWarning, match=warn_msg) as records:
1227+
Estimator(analyzer=analyzer, input=input_type).fit_transform(data)
1228+
assert len(records) == 1
1229+
assert warn_msg in str(records[0])
1230+
1231+
1232+
@pytest.mark.parametrize('Estimator',
1233+
[CountVectorizer, TfidfVectorizer, HashingVectorizer])
1234+
def test_callable_analyzer_reraise_error(tmpdir, Estimator):
1235+
# check if a custom exception from the analyzer is shown to the user
1236+
def analyzer(doc):
1237+
raise Exception("testing")
1238+
1239+
f = tmpdir.join("file.txt")
1240+
f.write("sample content\n")
1241+
1242+
with pytest.raises(Exception, match="testing"):
1243+
Estimator(analyzer=analyzer, input='file').fit_transform([f])

sklearn/feature_extraction/text.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
3232
from ..utils import _IS_32BIT
3333
from ..utils.fixes import _astype_copy_false
34+
from ..exceptions import ChangedBehaviorWarning
3435

3536

3637
__all__ = ['HashingVectorizer',
@@ -304,10 +305,34 @@ def _check_stop_words_consistency(self, stop_words, preprocess, tokenize):
304305
self._stop_words_id = id(self.stop_words)
305306
return 'error'
306307

308+
def _validate_custom_analyzer(self):
309+
# This is to check if the given custom analyzer expects file or a
310+
# filename instead of data.
311+
# Behavior changed in v0.21, function could be removed in v0.23
312+
import tempfile
313+
with tempfile.NamedTemporaryFile() as f:
314+
fname = f.name
315+
# now we're sure fname doesn't exist
316+
317+
msg = ("Since v0.21, vectorizers pass the data to the custom analyzer "
318+
"and not the file names or the file objects. This warning "
319+
"will be removed in v0.23.")
320+
try:
321+
self.analyzer(fname)
322+
except FileNotFoundError:
323+
warnings.warn(msg, ChangedBehaviorWarning)
324+
except AttributeError as e:
325+
if str(e) == "'str' object has no attribute 'read'":
326+
warnings.warn(msg, ChangedBehaviorWarning)
327+
except Exception:
328+
pass
329+
307330
def build_analyzer(self):
308331
"""Return a callable that handles preprocessing and tokenization"""
309332
if callable(self.analyzer):
310-
return self.analyzer
333+
if self.input in ['file', 'filename']:
334+
self._validate_custom_analyzer()
335+
return lambda doc: self.analyzer(self.decode(doc))
311336

312337
preprocess = self.build_preprocessor()
313338

@@ -490,6 +515,11 @@ class HashingVectorizer(BaseEstimator, VectorizerMixin, TransformerMixin):
490515
If a callable is passed it is used to extract the sequence of features
491516
out of the raw, unprocessed input.
492517
518+
.. versionchanged:: 0.21
519+
Since v0.21, if ``input`` is ``filename`` or ``file``, the data is
520+
first read from the file and then passed to the given callable
521+
analyzer.
522+
493523
n_features : integer, default=(2 ** 20)
494524
The number of features (columns) in the output matrices. Small numbers
495525
of features are likely to cause hash collisions, but large numbers
@@ -745,6 +775,11 @@ class CountVectorizer(BaseEstimator, VectorizerMixin):
745775
If a callable is passed it is used to extract the sequence of features
746776
out of the raw, unprocessed input.
747777
778+
.. versionchanged:: 0.21
779+
Since v0.21, if ``input`` is ``filename`` or ``file``, the data is
780+
first read from the file and then passed to the given callable
781+
analyzer.
782+
748783
max_df : float in range [0.0, 1.0] or int, default=1.0
749784
When building the vocabulary ignore terms that have a document
750785
frequency strictly higher than the given threshold (corpus-specific
@@ -1369,6 +1404,11 @@ class TfidfVectorizer(CountVectorizer):
13691404
If a callable is passed it is used to extract the sequence of features
13701405
out of the raw, unprocessed input.
13711406
1407+
.. versionchanged:: 0.21
1408+
Since v0.21, if ``input`` is ``filename`` or ``file``, the data is
1409+
first read from the file and then passed to the given callable
1410+
analyzer.
1411+
13721412
stop_words : string {'english'}, list, or None (default=None)
13731413
If a string, it is passed to _check_stop_list and the appropriate stop
13741414
list is returned. 'english' is currently the only supported string

0 commit comments

Comments
 (0)