Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 46 additions & 49 deletions plotly/_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@


def _get_initial_max_subplot_ids():
max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types}
max_subplot_ids["xaxis"] = 0
max_subplot_ids["yaxis"] = 0
return max_subplot_ids
# Use a static template dictionary to avoid repeated dict creation
# This is read-only, so copy on each call
d = dict.fromkeys(_single_subplot_types, 0)
d["xaxis"] = 0
d["yaxis"] = 0
return d


def make_subplots(
Expand Down Expand Up @@ -969,53 +971,44 @@ def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=No
if max_subplot_ids is None:
max_subplot_ids = _get_initial_max_subplot_ids()

# Get axis label and anchor
x_cnt = max_subplot_ids["xaxis"] + 1
y_cnt = max_subplot_ids["yaxis"] + 1

# Compute x/y labels (the values of trace.xaxis/trace.yaxis
x_label = "x{cnt}".format(cnt=x_cnt if x_cnt > 1 else "")
y_label = "y{cnt}".format(cnt=y_cnt if y_cnt > 1 else "")

# Anchor x and y axes to each other
# Use fast f-string formatting
x_label = f"x{x_cnt}" if x_cnt > 1 else "x"
y_label = f"y{y_cnt}" if y_cnt > 1 else "y"
x_anchor, y_anchor = y_label, x_label
xaxis_name = f"xaxis{x_cnt}" if x_cnt > 1 else "xaxis"
yaxis_name = f"yaxis{y_cnt}" if y_cnt > 1 else "yaxis"

# Build layout.xaxis/layout.yaxis containers
xaxis_name = "xaxis{cnt}".format(cnt=x_cnt if x_cnt > 1 else "")
yaxis_name = "yaxis{cnt}".format(cnt=y_cnt if y_cnt > 1 else "")
x_axis = {"domain": x_domain, "anchor": x_anchor}
y_axis = {"domain": y_domain, "anchor": y_anchor}

layout[xaxis_name] = x_axis
layout[yaxis_name] = y_axis

subplot_refs = [
SubplotRef(
subplot_type="xy",
layout_keys=(xaxis_name, yaxis_name),
trace_kwargs={"xaxis": x_label, "yaxis": y_label},
)
]
ref0 = SubplotRef(
subplot_type="xy",
layout_keys=(xaxis_name, yaxis_name),
trace_kwargs={"xaxis": x_label, "yaxis": y_label},
)

subplot_refs = [ref0]

if secondary_y:
y_cnt += 1
secondary_yaxis_name = "yaxis{cnt}".format(cnt=y_cnt if y_cnt > 1 else "")
secondary_y_label = "y{cnt}".format(cnt=y_cnt)

# Add secondary y-axis to subplot reference
subplot_refs.append(
SubplotRef(
subplot_type="xy",
layout_keys=(xaxis_name, secondary_yaxis_name),
trace_kwargs={"xaxis": x_label, "yaxis": secondary_y_label},
)
y_cnt_sec = y_cnt + 1
secondary_yaxis_name = f"yaxis{y_cnt_sec}" if y_cnt_sec > 1 else "yaxis"
secondary_y_label = f"y{y_cnt_sec}"
ref1 = SubplotRef(
subplot_type="xy",
layout_keys=(xaxis_name, secondary_yaxis_name),
trace_kwargs={"xaxis": x_label, "yaxis": secondary_y_label},
)

# Add secondary y axis to layout
subplot_refs.append(ref1)
secondary_y_axis = {"anchor": y_anchor, "overlaying": y_label, "side": "right"}
layout[secondary_yaxis_name] = secondary_y_axis

# increment max_subplot_ids
y_cnt = y_cnt_sec # update counter if secondary_y

max_subplot_ids["xaxis"] = x_cnt
max_subplot_ids["yaxis"] = y_cnt

Expand All @@ -1028,11 +1021,13 @@ def _init_subplot_single(
if max_subplot_ids is None:
max_subplot_ids = _get_initial_max_subplot_ids()

# Add scene to layout
cnt = max_subplot_ids[subplot_type] + 1
label = "{subplot_type}{cnt}".format(
subplot_type=subplot_type, cnt=cnt if cnt > 1 else ""
)
if cnt > 1:
label = f"{subplot_type}{cnt}"
else:
label = f"{subplot_type}"

# Use tuple directly for keys
scene = dict(domain={"x": x_domain, "y": y_domain})
layout[label] = scene

Expand All @@ -1044,7 +1039,6 @@ def _init_subplot_single(
subplot_type=subplot_type, layout_keys=(label,), trace_kwargs={trace_key: label}
)

# increment max_subplot_id
max_subplot_ids[subplot_type] = cnt

return (subplot_ref,)
Expand Down Expand Up @@ -1088,21 +1082,21 @@ def _subplot_type_for_trace_type(trace_type):


def _validate_coerce_subplot_type(subplot_type):
# Lowercase subplot_type
# Lowercase once
orig_subplot_type = subplot_type
subplot_type = subplot_type.lower()
subplot_type_lc = subplot_type.lower()

# Check if it's a named subplot type
if subplot_type in _subplot_types:
return subplot_type
if subplot_type_lc in _subplot_types:
return subplot_type_lc

# Try to determine subplot type for trace
subplot_type = _subplot_type_for_trace_type(subplot_type)
subplot_type_val = _subplot_type_for_trace_type(subplot_type_lc)

if subplot_type is None:
if subplot_type_val is None:
raise ValueError("Unsupported subplot type: {}".format(repr(orig_subplot_type)))
else:
return subplot_type
return subplot_type_val


def _init_subplot(
Expand All @@ -1117,8 +1111,11 @@ def _init_subplot(
# Clamp domain elements between [0, 1].
# This is only needed to combat numerical precision errors
# See GH1031
x_domain = [max(0.0, x_domain[0]), min(1.0, x_domain[1])]
y_domain = [max(0.0, y_domain[0]), min(1.0, y_domain[1])]
# Directly build the list: this is very hot, just return the minimum/maximum
x0, x1 = x_domain
y0, y1 = y_domain
x_domain = [max(0.0, x0), min(1.0, x1)]
y_domain = [max(0.0, y0), min(1.0, y1)]

if subplot_type == "xy":
subplot_refs = _init_subplot_xy(
Expand Down