Skip to content

Commit 186858e

Browse files
committed
ENH get rid of most imports in test_common
1 parent 2a47d8d commit 186858e

File tree

1 file changed

+26
-35
lines changed

1 file changed

+26
-35
lines changed

sklearn/tests/test_common.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,14 @@
3131
from sklearn.base import (clone, ClassifierMixin, RegressorMixin,
3232
TransformerMixin, ClusterMixin)
3333
from sklearn.utils import shuffle
34-
from sklearn.preprocessing import StandardScaler, Scaler
34+
from sklearn.preprocessing import StandardScaler
3535
from sklearn.datasets import (load_iris, load_boston, make_blobs,
3636
make_classification)
3737
from sklearn.metrics import accuracy_score, adjusted_rand_score, f1_score
3838

3939
from sklearn.lda import LDA
4040
from sklearn.svm.base import BaseLibSVM
4141

42-
# import "special" estimators
43-
from sklearn.pls import _PLS, PLSCanonical, PLSRegression, CCA, PLSSVD
44-
from sklearn.feature_selection import SelectKBest
45-
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
46-
from sklearn.kernel_approximation import AdditiveChi2Sampler
47-
from sklearn.preprocessing import Binarizer, Normalizer
48-
from sklearn.cluster import (WardAgglomeration, AffinityPropagation,
49-
SpectralClustering)
50-
from sklearn.random_projection import (GaussianRandomProjection,
51-
SparseRandomProjection)
52-
5342
from sklearn.cross_validation import train_test_split
5443

5544
dont_test = ['SparseCoder', 'EllipticEnvelope', 'EllipticEnvelop',
@@ -170,7 +159,7 @@ def test_transformers():
170159
if name in dont_test:
171160
continue
172161
# these don't actually fit the data:
173-
if Trans in [AdditiveChi2Sampler, Binarizer, Normalizer]:
162+
if name in ['AdditiveChi2Sampler', 'Binarizer', 'Normalizer']:
174163
continue
175164
# catch deprecation warnings
176165
with warnings.catch_warnings(record=True):
@@ -179,12 +168,12 @@ def test_transformers():
179168
if hasattr(trans, 'compute_importances'):
180169
trans.compute_importances = True
181170

182-
if Trans is SelectKBest:
171+
if name == 'SelectKBest':
183172
# SelectKBest has a default of k=10
184173
# which is more feature than we have.
185174
trans.k = 1
186-
elif Trans in [GaussianRandomProjection,
187-
SparseRandomProjection]:
175+
elif name in ['GaussianRandomProjection',
176+
'SparseRandomProjection']:
188177
# Due to the jl lemma and very few samples, the number
189178
# of components of the random matrix projection will be greater
190179
# than the number of features.
@@ -193,7 +182,7 @@ def test_transformers():
193182

194183
# fit
195184

196-
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
185+
if name in ('_PLS', 'PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'):
197186
random_state = np.random.RandomState(seed=12345)
198187
y_ = np.vstack([y, 2 * y + random_state.randint(2, size=len(y))])
199188
y_ = y_.T
@@ -216,7 +205,8 @@ def test_transformers():
216205
continue
217206

218207
if hasattr(trans, 'transform'):
219-
if Trans in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
208+
if name in ('_PLS', 'PLSCanonical', 'PLSRegression', 'CCA',
209+
'PLSSVD'):
220210
X_pred2 = trans.transform(X, y_)
221211
X_pred3 = trans.fit_transform(X, y=y_)
222212
else:
@@ -257,10 +247,10 @@ def test_transformers_sparse_data():
257247
continue
258248
# catch deprecation warnings
259249
with warnings.catch_warnings(record=True):
260-
if Trans in [Scaler, StandardScaler]:
250+
if name in ['Scaler', 'StandardScaler']:
261251
trans = Trans(with_mean=False)
262-
elif Trans in [GaussianRandomProjection,
263-
SparseRandomProjection]:
252+
elif name in ['GaussianRandomProjection',
253+
'SparseRandomProjection']:
264254
# Due to the jl lemma and very few samples, the number
265255
# of components of the random matrix projection will be greater
266256
# than the number of features.
@@ -309,14 +299,15 @@ def test_estimators_nan_inf():
309299
for name, Est in estimators:
310300
if name in dont_test:
311301
continue
312-
if Est in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
302+
if name in ('_PLS', 'PLSCanonical', 'PLSRegression', 'CCA',
303+
'PLSSVD'):
313304
continue
314305

315306
# catch deprecation warnings
316307
with warnings.catch_warnings(record=True):
317308
est = Est()
318-
if Est in [GaussianRandomProjection,
319-
SparseRandomProjection]:
309+
if name in ['GaussianRandomProjection',
310+
'SparseRandomProjection']:
320311
# Due to the jl lemma and very few samples, the number
321312
# of components of the random matrix projection will be
322313
# greater
@@ -430,7 +421,7 @@ def test_clustering():
430421
n_samples, n_features = X.shape
431422
X = StandardScaler().fit_transform(X)
432423
for name, Alg in clustering:
433-
if Alg is WardAgglomeration:
424+
if name == 'WardAgglomeration':
434425
# this is clustering on the features
435426
# let's not test that here.
436427
continue
@@ -440,7 +431,7 @@ def test_clustering():
440431
if hasattr(alg, "n_clusters"):
441432
alg.set_params(n_clusters=3)
442433
set_random_state(alg)
443-
if Alg is AffinityPropagation:
434+
if name == 'AffinityPropagation':
444435
alg.set_params(preference=-100)
445436
# fit
446437
alg.fit(X)
@@ -449,7 +440,7 @@ def test_clustering():
449440
pred = alg.labels_
450441
assert_greater(adjusted_rand_score(pred, y), 0.4)
451442
# fit another time with ``fit_predict`` and compare results
452-
if Alg is SpectralClustering:
443+
if name is 'SpectralClustering':
453444
# there is no way to make Spectral clustering deterministic :(
454445
continue
455446
set_random_state(alg)
@@ -476,7 +467,7 @@ def test_classifiers_train():
476467
for name, Clf in classifiers:
477468
if name in dont_test:
478469
continue
479-
if Clf in [MultinomialNB, BernoulliNB]:
470+
if name in ['MultinomialNB', 'BernoulliNB']:
480471
# TODO also test these!
481472
continue
482473
# catch deprecation warnings
@@ -544,7 +535,7 @@ def test_classifiers_classes():
544535
for name, Clf in classifiers:
545536
if name in dont_test:
546537
continue
547-
if Clf in [MultinomialNB, BernoulliNB]:
538+
if name in ['MultinomialNB', 'BernoulliNB']:
548539
# TODO also test these!
549540
continue
550541

@@ -573,7 +564,7 @@ def test_regressors_int():
573564
X = StandardScaler().fit_transform(X)
574565
y = np.random.randint(2, size=X.shape[0])
575566
for name, Reg in regressors:
576-
if name in dont_test or Reg in (CCA,):
567+
if name in dont_test or name in ('CCA',):
577568
continue
578569
# catch deprecation warnings
579570
with warnings.catch_warnings(record=True):
@@ -583,7 +574,7 @@ def test_regressors_int():
583574
set_random_state(reg1)
584575
set_random_state(reg2)
585576

586-
if Reg in (_PLS, PLSCanonical, PLSRegression):
577+
if name in ('_PLS', 'PLSCanonical', 'PLSRegression'):
587578
y_ = np.vstack([y, 2 * y + np.random.randint(2, size=len(y))])
588579
y_ = y_.T
589580
else:
@@ -621,15 +612,15 @@ def test_regressors_train():
621612
assert_raises(ValueError, reg.fit, X, y[:-1])
622613
# fit
623614
try:
624-
if Reg in (_PLS, PLSCanonical, PLSRegression, CCA):
615+
if name in ('_PLS', 'PLSCanonical', 'PLSRegression', 'CCA'):
625616
y_ = np.vstack([y, 2 * y + np.random.randint(2, size=len(y))])
626617
y_ = y_.T
627618
else:
628619
y_ = y
629620
reg.fit(X, y_)
630621
reg.predict(X)
631622

632-
if Reg not in (PLSCanonical, CCA): # TODO: find out why
623+
if name not in ('PLSCanonical', 'CCA'): # TODO: find out why
633624
assert_greater(reg.score(X, y_), 0.5)
634625
except Exception as e:
635626
print(reg)
@@ -769,8 +760,8 @@ def test_estimators_overwrite_params():
769760
# for MiniBatchDictLearning
770761
est.batch_size = 1
771762

772-
if Est in [GaussianRandomProjection,
773-
SparseRandomProjection]:
763+
if name in ['GaussianRandomProjection',
764+
'SparseRandomProjection']:
774765
# Due to the jl lemma and very few samples, the number
775766
# of components of the random matrix projection will be
776767
# greater

0 commit comments

Comments
 (0)