Skip to content

Commit 388999b

Browse files
authored
API Deprecate positional arguments in tree module (scikit-learn#16966)
1 parent 1523f39 commit 388999b

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

sklearn/tree/_classes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ..utils import compute_sample_weight
3737
from ..utils.multiclass import check_classification_targets
3838
from ..utils.validation import check_is_fitted
39+
from ..utils.validation import _deprecate_positional_args
3940

4041
from ._criterion import Criterion
4142
from ._splitter import Splitter
@@ -82,7 +83,8 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
8283
"""
8384

8485
@abstractmethod
85-
def __init__(self,
86+
@_deprecate_positional_args
87+
def __init__(self, *,
8688
criterion,
8789
splitter,
8890
max_depth,
@@ -815,7 +817,8 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
815817
array([ 1. , 0.93..., 0.86..., 0.93..., 0.93...,
816818
0.93..., 0.93..., 1. , 0.93..., 1. ])
817819
"""
818-
def __init__(self,
820+
@_deprecate_positional_args
821+
def __init__(self, *,
819822
criterion="gini",
820823
splitter="best",
821824
max_depth=None,
@@ -1169,7 +1172,8 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
11691172
array([-0.39..., -0.46..., 0.02..., 0.06..., -0.50...,
11701173
0.16..., 0.11..., -0.73..., -0.30..., -0.00...])
11711174
"""
1172-
def __init__(self,
1175+
@_deprecate_positional_args
1176+
def __init__(self, *,
11731177
criterion="mse",
11741178
splitter="best",
11751179
max_depth=None,
@@ -1499,7 +1503,8 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
14991503
>>> cls.score(X_test, y_test)
15001504
0.8947...
15011505
"""
1502-
def __init__(self,
1506+
@_deprecate_positional_args
1507+
def __init__(self, *,
15031508
criterion="gini",
15041509
splitter="random",
15051510
max_depth=None,
@@ -1716,7 +1721,8 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
17161721
>>> reg.score(X_test, y_test)
17171722
0.33...
17181723
"""
1719-
def __init__(self,
1724+
@_deprecate_positional_args
1725+
def __init__(self, *,
17201726
criterion="mse",
17211727
splitter="random",
17221728
max_depth=None,

sklearn/tree/_export.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818

1919
from ..utils.validation import check_is_fitted
20+
from ..utils.validation import _deprecate_positional_args
2021
from ..base import is_classifier
2122

2223
from . import _criterion
@@ -77,7 +78,8 @@ def __repr__(self):
7778
SENTINEL = Sentinel()
7879

7980

80-
def plot_tree(decision_tree, max_depth=None, feature_names=None,
81+
@_deprecate_positional_args
82+
def plot_tree(decision_tree, *, max_depth=None, feature_names=None,
8183
class_names=None, label='all', filled=False,
8284
impurity=True, node_ids=False,
8385
proportion=False, rotate='deprecated', rounded=False,
@@ -656,7 +658,8 @@ def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
656658
ax.annotate("\n (...) \n", xy_parent, xy, **kwargs)
657659

658660

659-
def export_graphviz(decision_tree, out_file=None, max_depth=None,
661+
@_deprecate_positional_args
662+
def export_graphviz(decision_tree, out_file=None, *, max_depth=None,
660663
feature_names=None, class_names=None, label='all',
661664
filled=False, leaves_parallel=False, impurity=True,
662665
node_ids=False, proportion=False, rotate=False,
@@ -807,7 +810,8 @@ def compute_depth_(current_node, current_depth,
807810
return max(depths)
808811

809812

810-
def export_text(decision_tree, feature_names=None, max_depth=10,
813+
@_deprecate_positional_args
814+
def export_text(decision_tree, *, feature_names=None, max_depth=10,
811815
spacing=3, decimals=2, show_weights=False):
812816
"""Build a text report showing the rules of a decision tree.
813817

sklearn/tree/tests/test_export.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,5 @@ def test_not_fitted_tree(pyplot):
465465

466466
# Testing if not fitted tree throws the correct error
467467
clf = DecisionTreeRegressor()
468-
out = StringIO()
469468
with pytest.raises(NotFittedError):
470-
plot_tree(clf, out)
469+
plot_tree(clf)

0 commit comments

Comments
 (0)