Skip to content

Commit f37277a

Browse files
qinhanmin2014TomDLT
authored andcommitted
MNT Clean up plot_tree (remove matplotlib < 1.5) (scikit-learn#14321)
1 parent 4ff6de8 commit f37277a

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

sklearn/tree/export.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def plot_tree(decision_tree, max_depth=None, feature_names=None,
8585
8686
The sample counts that are shown are weighted with any sample_weights that
8787
might be present.
88-
This function requires matplotlib, and works best with matplotlib >= 1.5.
8988
9089
The visualization is fit automatically to the size of the axis.
9190
Use the ``figsize`` or ``dpi`` arguments of ``plt.figure`` to control
@@ -541,9 +540,6 @@ def __init__(self, max_depth=None, feature_names=None,
541540
self.bbox_args = dict(fc='w')
542541
if self.rounded:
543542
self.bbox_args['boxstyle'] = "round"
544-
else:
545-
# matplotlib <1.5 requires explicit boxstyle
546-
self.bbox_args['boxstyle'] = "square"
547543

548544
self.arrow_args = dict(arrowstyle="<-")
549545

@@ -599,27 +595,20 @@ def export(self, decision_tree, ax=None):
599595
# get figure to data transform
600596
# adjust fontsize to avoid overlap
601597
# get max box width and height
602-
try:
603-
extents = [ann.get_bbox_patch().get_window_extent()
604-
for ann in anns]
605-
max_width = max([extent.width for extent in extents])
606-
max_height = max([extent.height for extent in extents])
607-
# width should be around scale_x in axis coordinates
608-
size = anns[0].get_fontsize() * min(scale_x / max_width,
609-
scale_y / max_height)
610-
for ann in anns:
611-
ann.set_fontsize(size)
612-
except AttributeError:
613-
# matplotlib < 1.5
614-
warnings.warn("Automatic scaling of tree plots requires "
615-
"matplotlib 1.5 or higher. Please specify "
616-
"fontsize.")
598+
extents = [ann.get_bbox_patch().get_window_extent()
599+
for ann in anns]
600+
max_width = max([extent.width for extent in extents])
601+
max_height = max([extent.height for extent in extents])
602+
# width should be around scale_x in axis coordinates
603+
size = anns[0].get_fontsize() * min(scale_x / max_width,
604+
scale_y / max_height)
605+
for ann in anns:
606+
ann.set_fontsize(size)
617607

618608
return anns
619609

620610
def recurse(self, node, tree, ax, scale_x, scale_y, height, depth=0):
621-
# need to copy bbox args because matplotib <1.5 modifies them
622-
kwargs = dict(bbox=self.bbox_args.copy(), ha='center', va='center',
611+
kwargs = dict(bbox=self.bbox_args, ha='center', va='center',
623612
zorder=100 - 10 * depth, xycoords='axes pixels')
624613

625614
if self.fontsize is not None:

0 commit comments

Comments
 (0)