Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 153 additions & 124 deletions plotly/figure_factory/_facet_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

import math
from numbers import Number
import pandas as pd

pd = optional_imports.get_module("pandas")

TICK_COLOR = "#969696"
# Font colour of the facet-label annotations built by `_annotation_dict`.
# Assigned exactly once: a second assignment directly above this one would be
# dead code, because only the last binding is visible to the rest of the module.
AXIS_TITLE_COLOR = "#444"
AXIS_TITLE_SIZE = 12
GRID_COLOR = "#ffffff"
LEGEND_COLOR = "#efefef"
Expand Down Expand Up @@ -40,13 +41,12 @@ def _is_flipped(num):


def _return_label(original_label, facet_labels, facet_var):
# Fast path: dict lookup, string formatting, else passthrough
if isinstance(facet_labels, dict):
label = facet_labels[original_label]
elif isinstance(facet_labels, str):
label = "{}: {}".format(facet_var, original_label)
else:
label = original_label
return label
return facet_labels.get(original_label, original_label)
if isinstance(facet_labels, str):
return f"{facet_var}: {original_label}"
return original_label


def _legend_annotation(color_name):
Expand All @@ -68,45 +68,39 @@ def _legend_annotation(color_name):
def _annotation_dict(
    text, lane, num_of_lanes, SUBPLOT_SPACING, row_col="col", flipped=True
):
    """Build the annotation dict that labels one facet lane.

    The lanes are assumed to split the unit interval evenly, separated by
    ``SUBPLOT_SPACING``; the label is centred on its lane along that axis and
    pinned near 1.0 on the other axis. Both coordinates use ``"paper"``
    references.

    Parameters
    ----------
    text
        Label content; converted with ``str()``.
    lane
        1-based index of the lane being labelled.
    num_of_lanes
        Total number of lanes along the faceted direction.
    SUBPLOT_SPACING
        Gap between neighbouring lanes, in the same unit-interval fraction.
    row_col
        ``"col"`` centres the label along x; any other value (normally
        ``"row"``) centres it along y.
    flipped
        When true the label sits exactly at 1.0 with edge anchoring and the
        text angles 270 (col) / 0 (row); otherwise it sits at 1.03, anchored
        at its centre, with angles 0 (col) / 90 (row).

    Returns
    -------
    dict
        Keys ``textangle``, ``xanchor``, ``yanchor``, ``x``, ``y``,
        ``showarrow``, ``xref``, ``yref``, ``text`` and ``font``.
    """
    # Size of a single lane once the gaps between lanes are removed.
    lane_size = (1.0 - (num_of_lanes - 1) * SUBPLOT_SPACING) / num_of_lanes
    # Centre of the requested lane; the same value is needed in every branch,
    # so it is computed once.
    center = (lane - 1) * (lane_size + SUBPLOT_SPACING) + 0.5 * lane_size

    is_col = row_col == "col"

    if flipped:
        edge = 1.0
        if is_col:
            xanchor, yanchor, textangle = "center", "bottom", 270
        else:
            xanchor, yanchor, textangle = "left", "middle", 0
    else:
        edge = 1.03
        xanchor, yanchor = "center", "middle"
        textangle = 0 if is_col else 90

    # Column labels vary along x and are pinned in y; row labels the reverse.
    x, y = (center, edge) if is_col else (edge, center)

    return dict(
        textangle=textangle,
        xanchor=xanchor,
        yanchor=yanchor,
        x=x,
        y=y,
        showarrow=False,
        xref="paper",
        yref="paper",
        text=str(text),
        font=dict(size=13, color=AXIS_TITLE_COLOR),
    )


def _axis_title_annotation(text, x_or_y_axis):
Expand Down Expand Up @@ -174,9 +168,15 @@ def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False, flipped_cols=F


def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker):
if trace_type in ["scatter", "scattergl"]:
trace["mode"] = "markers"
trace["marker"] = dict(color=color, **kwargs_marker)
# Optimize away repeated dict+kwargs unpack if kwargs_marker empty
if trace_type in ("scatter", "scattergl"):
trace['mode'] = "markers"
if not kwargs_marker:
trace["marker"] = {'color': color}
else:
mk = {'color': color}
mk.update(kwargs_marker)
trace["marker"] = mk
return trace


Expand All @@ -201,7 +201,10 @@ def _facet_grid_color_categorical(
kwargs_trace,
kwargs_marker,
):

"""
Optimized for speed and memory by minimizing repeated DataFrame computations,
pre-caching .unique(), simplifying trace construction, and reducing calls.
"""
fig = make_subplots(
rows=num_of_rows,
cols=num_of_cols,
Expand All @@ -213,117 +216,142 @@ def _facet_grid_color_categorical(
)

annotations = []

# Pre-cache value arrays
if not facet_row and not facet_col:
color_groups = list(df.groupby(color_name))
for group in color_groups:
trace = dict(
type=trace_type,
name=group[0],
marker=dict(color=colormap[group[0]]),
# Only color groups
# groupby-apply is still O(N), but groupby is used only once.
for color_val, group_frame in df.groupby(color_name, sort=False):
trace = {
'type': trace_type,
'name': color_val,
'marker': {'color': colormap[color_val]},
**kwargs_trace,
)
}
if x:
trace["x"] = group[1][x]
trace['x'] = group_frame[x]
if y:
trace["y"] = group[1][y]
trace = _make_trace_for_scatter(
trace, trace_type, colormap[group[0]], **kwargs_marker
)

trace['y'] = group_frame[y]
trace = _make_trace_for_scatter(trace, trace_type, colormap[color_val], **kwargs_marker)
fig.append_trace(trace, 1, 1)
# No annotations in this branch
return fig, annotations

elif (facet_row and not facet_col) or (not facet_row and facet_col):
groups_by_facet = list(df.groupby(facet_row if facet_row else facet_col))
for j, group in enumerate(groups_by_facet):
for color_val in df[color_name].unique():
data_by_color = group[1][group[1][color_name] == color_val]
trace = dict(
type=trace_type,
name=color_val,
marker=dict(color=colormap[color_val]),
# Only 1d faceting (row or col)
facet_var = facet_row if facet_row else facet_col
facet_labels = facet_row_labels if facet_row else facet_col_labels
N_lanes = num_of_rows if facet_row else num_of_cols
unique_colors = df[color_name].unique()
facet_grouped = df.groupby(facet_var, sort=False)

# Build mapping for color mask to avoid repeated slicing
for idx, (fval, fgroup) in enumerate(facet_grouped):
# Use mask for color selection
for color_val in unique_colors:
# Faster boolean indexing than chained selection
color_mask = (fgroup[color_name].values == color_val)
if color_mask.any():
# Slicing once instead of twice
slice_df = fgroup.loc[color_mask]
else:
slice_df = fgroup.iloc[0:0] # empty selection

trace = {
'type': trace_type,
'name': color_val,
'marker': {'color': colormap[color_val]},
**kwargs_trace,
)
}
if x:
trace["x"] = data_by_color[x]
trace['x'] = slice_df[x]
if y:
trace["y"] = data_by_color[y]
trace = _make_trace_for_scatter(
trace, trace_type, colormap[color_val], **kwargs_marker
)

trace['y'] = slice_df[y]
trace = _make_trace_for_scatter(trace, trace_type, colormap[color_val], **kwargs_marker)
# Row or col selection
fig.append_trace(
trace, j + 1 if facet_row else 1, 1 if facet_row else j + 1
trace,
idx + 1 if facet_row else 1,
1 if facet_row else idx + 1
)

label = _return_label(
group[0],
facet_row_labels if facet_row else facet_col_labels,
facet_row if facet_row else facet_col,
)

label = _return_label(fval, facet_labels, facet_var)
annotations.append(
_annotation_dict(
label,
num_of_rows - j if facet_row else j + 1,
num_of_rows - idx if facet_row else idx + 1,
num_of_rows if facet_row else num_of_cols,
SUBPLOT_SPACING,
"row" if facet_row else "col",
flipped_rows,
)
)
return fig, annotations

elif facet_row and facet_col:
groups_by_facets = list(df.groupby([facet_row, facet_col]))
tuple_to_facet_group = {item[0]: item[1] for item in groups_by_facets}

# 2d facet grid: must avoid repeated groupby, repeated .unique(), etc.
groupby_keys = [facet_row, facet_col]
# Get all needed (row,col) levels, colorvals
row_values = df[facet_row].unique()
col_values = df[facet_col].unique()
color_vals = df[color_name].unique()
for row_count, x_val in enumerate(row_values):
for col_count, y_val in enumerate(col_values):
try:
group = tuple_to_facet_group[(x_val, y_val)]
except KeyError:
group = pd.DataFrame(
[[None, None, None]], columns=[x, y, color_name]
)

for color_val in color_vals:
if group.values.tolist() != [[None, None, None]]:
group_filtered = group[group[color_name] == color_val]
# Group to dict-of-dataframe for fast lookup (groupby list-of-labels returns tuple keys)
facet_dict = dict()
for keys, group_df in df.groupby(groupby_keys, sort=False):
facet_dict[keys] = group_df

trace = dict(
type=trace_type,
name=color_val,
marker=dict(color=colormap[color_val]),
**kwargs_trace,
)
new_x = group_filtered[x]
new_y = group_filtered[y]
# Use allocation only for missing group
empty_df = pd.DataFrame([[None, None, None]], columns=[x, y, color_name])

# For each facet cell, slice once
for row_count, row_val in enumerate(row_values):
for col_count, col_val in enumerate(col_values):
# Only do lookup once
group = facet_dict.get((row_val, col_val), empty_df)

is_empty = group is empty_df
# Pre-extract color column np array for filtering (avoid repeatedly accessing .values)
if not is_empty:
group_color_values = group[color_name].values
for color_val in color_vals:
if not is_empty:
color_mask = (group_color_values == color_val)
slice_df = group.loc[color_mask]
# Only build trace if mask not empty
if color_mask.any():
trace = {
'type': trace_type,
'name': color_val,
'marker': {'color': colormap[color_val]},
**kwargs_trace,
}
if x:
trace['x'] = slice_df[x]
if y:
trace['y'] = slice_df[y]
else:
# No data for this color, skip adding trace
continue
else:
trace = dict(
type=trace_type,
name=color_val,
marker=dict(color=colormap[color_val]),
showlegend=False,
# Empty cell, just create an empty trace (hidden from legend)
trace = {
'type': trace_type,
'name': color_val,
'marker': {'color': colormap[color_val]},
'showlegend': False,
**kwargs_trace,
)
new_x = group[x]
new_y = group[y]

if x:
trace["x"] = new_x
if y:
trace["y"] = new_y
trace = _make_trace_for_scatter(
trace, trace_type, colormap[color_val], **kwargs_marker
)

}
if x:
trace['x'] = group[x]
if y:
trace['y'] = group[y]
trace = _make_trace_for_scatter(trace, trace_type, colormap[color_val], **kwargs_marker)
fig.append_trace(trace, row_count + 1, col_count + 1)

# Annotations (only once per row/col)
if row_count == 0:
label = _return_label(
col_values[col_count], facet_col_labels, facet_col
)
label = _return_label(col_values[col_count], facet_col_labels, facet_col)
annotations.append(
_annotation_dict(
label,
Expand All @@ -334,6 +362,7 @@ def _facet_grid_color_categorical(
flipped=flipped_cols,
)
)
# Row annotation
label = _return_label(row_values[row_count], facet_row_labels, facet_row)
annotations.append(
_annotation_dict(
Expand Down