Skip to content

Commit ed932a7

Browse files
Mao771danielhomola
andauthored
Bugfix/max depth (scikit-learn-contrib#83)
* Fixed KeyError if max_depth does not present in estimator parameters * Typo fixed * Update boruta_py.py Co-authored-by: Daniel Homola <[email protected]>
1 parent fa6d659 commit ed932a7

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

boruta/boruta_py.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import scipy as sp
1414
from sklearn.utils import check_random_state, check_X_y
1515
from sklearn.base import TransformerMixin, BaseEstimator
16+
import warnings
1617

1718

1819
class BorutaPy(BaseEstimator, TransformerMixin):
@@ -401,7 +402,14 @@ def _transform(self, X, weak=False, return_df=False):
401402
return X
402403

403404
def _get_tree_num(self, n_feat):
404-
depth = self.estimator.get_params()['max_depth']
405+
depth = None
406+
try:
407+
depth = self.estimator.get_params()['max_depth']
408+
except KeyError:
409+
warnings.warn(
410+
"The estimator does not have a max_depth property, as a result "
411+
" the number of trees to use cannot be estimated automatically."
412+
)
405413
if depth == None:
406414
depth = 10
407415
# how many times a feature should be considered on average

0 commit comments

Comments
 (0)