3131from sklearn .base import (clone , ClassifierMixin , RegressorMixin ,
3232 TransformerMixin , ClusterMixin )
3333from sklearn .utils import shuffle
34- from sklearn .preprocessing import StandardScaler , Scaler
34+ from sklearn .preprocessing import StandardScaler
3535from sklearn .datasets import (load_iris , load_boston , make_blobs ,
3636 make_classification )
3737from sklearn .metrics import accuracy_score , adjusted_rand_score , f1_score
3838
3939from sklearn .lda import LDA
4040from sklearn .svm .base import BaseLibSVM
4141
42- # import "special" estimators
43- from sklearn .pls import _PLS , PLSCanonical , PLSRegression , CCA , PLSSVD
44- from sklearn .feature_selection import SelectKBest
45- from sklearn .naive_bayes import MultinomialNB , BernoulliNB
46- from sklearn .kernel_approximation import AdditiveChi2Sampler
47- from sklearn .preprocessing import Binarizer , Normalizer
48- from sklearn .cluster import (WardAgglomeration , AffinityPropagation ,
49- SpectralClustering )
50- from sklearn .random_projection import (GaussianRandomProjection ,
51- SparseRandomProjection )
52-
5342from sklearn .cross_validation import train_test_split
5443
5544dont_test = ['SparseCoder' , 'EllipticEnvelope' , 'EllipticEnvelop' ,
@@ -170,7 +159,7 @@ def test_transformers():
170159 if name in dont_test :
171160 continue
172161 # these don't actually fit the data:
173- if Trans in [AdditiveChi2Sampler , Binarizer , Normalizer ]:
162+ if name in [' AdditiveChi2Sampler' , ' Binarizer' , ' Normalizer' ]:
174163 continue
175164 # catch deprecation warnings
176165 with warnings .catch_warnings (record = True ):
@@ -179,12 +168,12 @@ def test_transformers():
179168 if hasattr (trans , 'compute_importances' ):
180169 trans .compute_importances = True
181170
182- if Trans is SelectKBest :
171+ if name == ' SelectKBest' :
183172 # SelectKBest has a default of k=10
184173 # which is more feature than we have.
185174 trans .k = 1
186- elif Trans in [GaussianRandomProjection ,
187- SparseRandomProjection ]:
175+ elif name in [' GaussianRandomProjection' ,
176+ ' SparseRandomProjection' ]:
188177 # Due to the jl lemma and very few samples, the number
189178 # of components of the random matrix projection will be greater
190179 # than the number of features.
@@ -193,7 +182,7 @@ def test_transformers():
193182
194183 # fit
195184
196- if Trans in (_PLS , PLSCanonical , PLSRegression , CCA , PLSSVD ):
185+ if name in (' _PLS' , ' PLSCanonical' , ' PLSRegression' , ' CCA' , ' PLSSVD' ):
197186 random_state = np .random .RandomState (seed = 12345 )
198187 y_ = np .vstack ([y , 2 * y + random_state .randint (2 , size = len (y ))])
199188 y_ = y_ .T
@@ -216,7 +205,8 @@ def test_transformers():
216205 continue
217206
218207 if hasattr (trans , 'transform' ):
219- if Trans in (_PLS , PLSCanonical , PLSRegression , CCA , PLSSVD ):
208+ if name in ('_PLS' , 'PLSCanonical' , 'PLSRegression' , 'CCA' ,
209+ 'PLSSVD' ):
220210 X_pred2 = trans .transform (X , y_ )
221211 X_pred3 = trans .fit_transform (X , y = y_ )
222212 else :
@@ -257,10 +247,10 @@ def test_transformers_sparse_data():
257247 continue
258248 # catch deprecation warnings
259249 with warnings .catch_warnings (record = True ):
260- if Trans in [Scaler , StandardScaler ]:
250+ if name in [' Scaler' , ' StandardScaler' ]:
261251 trans = Trans (with_mean = False )
262- elif Trans in [GaussianRandomProjection ,
263- SparseRandomProjection ]:
252+ elif name in [' GaussianRandomProjection' ,
253+ ' SparseRandomProjection' ]:
264254 # Due to the jl lemma and very few samples, the number
265255 # of components of the random matrix projection will be greater
266256 # than the number of features.
@@ -309,14 +299,15 @@ def test_estimators_nan_inf():
309299 for name , Est in estimators :
310300 if name in dont_test :
311301 continue
312- if Est in (_PLS , PLSCanonical , PLSRegression , CCA , PLSSVD ):
302+ if name in ('_PLS' , 'PLSCanonical' , 'PLSRegression' , 'CCA' ,
303+ 'PLSSVD' ):
313304 continue
314305
315306 # catch deprecation warnings
316307 with warnings .catch_warnings (record = True ):
317308 est = Est ()
318- if Est in [GaussianRandomProjection ,
319- SparseRandomProjection ]:
309+ if name in [' GaussianRandomProjection' ,
310+ ' SparseRandomProjection' ]:
320311 # Due to the jl lemma and very few samples, the number
321312 # of components of the random matrix projection will be
322313 # greater
@@ -430,7 +421,7 @@ def test_clustering():
430421 n_samples , n_features = X .shape
431422 X = StandardScaler ().fit_transform (X )
432423 for name , Alg in clustering :
433- if Alg is WardAgglomeration :
424+ if name == ' WardAgglomeration' :
434425 # this is clustering on the features
435426 # let's not test that here.
436427 continue
@@ -440,7 +431,7 @@ def test_clustering():
440431 if hasattr (alg , "n_clusters" ):
441432 alg .set_params (n_clusters = 3 )
442433 set_random_state (alg )
443- if Alg is AffinityPropagation :
434+ if name == ' AffinityPropagation' :
444435 alg .set_params (preference = - 100 )
445436 # fit
446437 alg .fit (X )
@@ -449,7 +440,7 @@ def test_clustering():
449440 pred = alg .labels_
450441 assert_greater (adjusted_rand_score (pred , y ), 0.4 )
451442 # fit another time with ``fit_predict`` and compare results
452- if Alg is SpectralClustering :
443+ if name is ' SpectralClustering' :
453444 # there is no way to make Spectral clustering deterministic :(
454445 continue
455446 set_random_state (alg )
@@ -476,7 +467,7 @@ def test_classifiers_train():
476467 for name , Clf in classifiers :
477468 if name in dont_test :
478469 continue
479- if Clf in [MultinomialNB , BernoulliNB ]:
470+ if name in [' MultinomialNB' , ' BernoulliNB' ]:
480471 # TODO also test these!
481472 continue
482473 # catch deprecation warnings
@@ -544,7 +535,7 @@ def test_classifiers_classes():
544535 for name , Clf in classifiers :
545536 if name in dont_test :
546537 continue
547- if Clf in [MultinomialNB , BernoulliNB ]:
538+ if name in [' MultinomialNB' , ' BernoulliNB' ]:
548539 # TODO also test these!
549540 continue
550541
@@ -573,7 +564,7 @@ def test_regressors_int():
573564 X = StandardScaler ().fit_transform (X )
574565 y = np .random .randint (2 , size = X .shape [0 ])
575566 for name , Reg in regressors :
576- if name in dont_test or Reg in (CCA ,):
567+ if name in dont_test or name in (' CCA' ,):
577568 continue
578569 # catch deprecation warnings
579570 with warnings .catch_warnings (record = True ):
@@ -583,7 +574,7 @@ def test_regressors_int():
583574 set_random_state (reg1 )
584575 set_random_state (reg2 )
585576
586- if Reg in (_PLS , PLSCanonical , PLSRegression ):
577+ if name in (' _PLS' , ' PLSCanonical' , ' PLSRegression' ):
587578 y_ = np .vstack ([y , 2 * y + np .random .randint (2 , size = len (y ))])
588579 y_ = y_ .T
589580 else :
@@ -621,15 +612,15 @@ def test_regressors_train():
621612 assert_raises (ValueError , reg .fit , X , y [:- 1 ])
622613 # fit
623614 try :
624- if Reg in (_PLS , PLSCanonical , PLSRegression , CCA ):
615+ if name in (' _PLS' , ' PLSCanonical' , ' PLSRegression' , ' CCA' ):
625616 y_ = np .vstack ([y , 2 * y + np .random .randint (2 , size = len (y ))])
626617 y_ = y_ .T
627618 else :
628619 y_ = y
629620 reg .fit (X , y_ )
630621 reg .predict (X )
631622
632- if Reg not in (PLSCanonical , CCA ): # TODO: find out why
623+ if name not in (' PLSCanonical' , ' CCA' ): # TODO: find out why
633624 assert_greater (reg .score (X , y_ ), 0.5 )
634625 except Exception as e :
635626 print (reg )
@@ -769,8 +760,8 @@ def test_estimators_overwrite_params():
769760 # for MiniBatchDictLearning
770761 est .batch_size = 1
771762
772- if Est in [GaussianRandomProjection ,
773- SparseRandomProjection ]:
763+ if name in [' GaussianRandomProjection' ,
764+ ' SparseRandomProjection' ]:
774765 # Due to the jl lemma and very few samples, the number
775766 # of components of the random matrix projection will be
776767 # greater
0 commit comments