diff --git a/packages/python/plotly/_plotly_utils/basevalidators.py b/packages/python/plotly/_plotly_utils/basevalidators.py index 8a629237c6d..814e29209e6 100644 --- a/packages/python/plotly/_plotly_utils/basevalidators.py +++ b/packages/python/plotly/_plotly_utils/basevalidators.py @@ -51,83 +51,6 @@ def to_scalar_or_list(v): return v -plotlyjsShortTypes = { - "int8": "i1", - "uint8": "u1", - "int16": "i2", - "uint16": "u2", - "int32": "i4", - "uint32": "u4", - "float32": "f4", - "float64": "f8", -} - -int8min = -128 -int8max = 127 -int16min = -32768 -int16max = 32767 -int32min = -2147483648 -int32max = 2147483647 - -uint8max = 255 -uint16max = 65535 -uint32max = 4294967295 - - -def to_typed_array_spec(v): - """ - Convert numpy array to plotly.js typed array spec - If not possible return the original value - """ - v = copy_to_readonly_numpy_array(v) - - np = get_module("numpy", should_load=False) - if not isinstance(v, np.ndarray): - return v - - dtype = str(v.dtype) - - # convert default Big Ints until we could support them in plotly.js - if dtype == "int64": - max = v.max() - min = v.min() - if max <= int8max and min >= int8min: - v = v.astype("int8") - elif max <= int16max and min >= int16min: - v = v.astype("int16") - elif max <= int32max and min >= int32min: - v = v.astype("int32") - else: - return v - - elif dtype == "uint64": - max = v.max() - min = v.min() - if max <= uint8max and min >= 0: - v = v.astype("uint8") - elif max <= uint16max and min >= 0: - v = v.astype("uint16") - elif max <= uint32max and min >= 0: - v = v.astype("uint32") - else: - return v - - dtype = str(v.dtype) - - if dtype in plotlyjsShortTypes: - arrObj = { - "dtype": plotlyjsShortTypes[dtype], - "bdata": base64.b64encode(v).decode("ascii"), - } - - if v.ndim > 1: - arrObj["shape"] = str(v.shape)[1:-1] - - return arrObj - - return v - - def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False): """ Convert an array-like value into a read-only numpy array @@ -292,15 +215,6 @@ def is_typed_array_spec(v): return isinstance(v, dict) and "bdata" in v and "dtype" in v -def has_skipped_key(all_parent_keys): - """ - Return whether any keys in the parent hierarchy are in the list of keys that - are skipped for conversion to the typed array spec - """ - skipped_keys = ["geojson", "layer", "range"] - return any(skipped_key in all_parent_keys for skipped_key in skipped_keys) - - def is_none_or_typed_array_spec(v): return v is None or is_typed_array_spec(v) @@ -500,10 +414,8 @@ def description(self): def validate_coerce(self, v): if is_none_or_typed_array_spec(v): pass - elif has_skipped_key(self.parent_name): - v = to_scalar_or_list(v) elif is_homogeneous_array(v): - v = to_typed_array_spec(v) + v = copy_to_readonly_numpy_array(v) elif is_simple_array(v): v = to_scalar_or_list(v) else: diff --git a/packages/python/plotly/_plotly_utils/utils.py b/packages/python/plotly/_plotly_utils/utils.py index e8a32e0c8ae..7d690dcb941 100644 --- a/packages/python/plotly/_plotly_utils/utils.py +++ b/packages/python/plotly/_plotly_utils/utils.py @@ -1,3 +1,4 @@ +import base64 import decimal import json as _json import sys @@ -5,7 +6,111 @@ from functools import reduce from _plotly_utils.optional_imports import get_module -from _plotly_utils.basevalidators import ImageUriValidator +from _plotly_utils.basevalidators import ( + ImageUriValidator, + copy_to_readonly_numpy_array, + is_homogeneous_array, +) + + +int8min = -128 +int8max = 127 +int16min = -32768 +int16max = 32767 +int32min = -2147483648 +int32max = 2147483647 + +uint8max = 255 +uint16max = 65535 +uint32max = 4294967295 + +plotlyjsShortTypes = { + "int8": "i1", + "uint8": "u1", + "int16": "i2", + "uint16": "u2", + "int32": "i4", + "uint32": "u4", + "float32": "f4", + "float64": "f8", +} + + +def to_typed_array_spec(v): + """ + Convert numpy array to plotly.js typed array spec + If not possible return the original value + """ + v = copy_to_readonly_numpy_array(v) + + np = get_module("numpy", should_load=False) + if not isinstance(v, np.ndarray): + return v + + dtype = str(v.dtype) + + # convert default Big Ints until we could support them in plotly.js + if dtype == "int64": + max = v.max() + min = v.min() + if max <= int8max and min >= int8min: + v = v.astype("int8") + elif max <= int16max and min >= int16min: + v = v.astype("int16") + elif max <= int32max and min >= int32min: + v = v.astype("int32") + else: + return v + + elif dtype == "uint64": + max = v.max() + min = v.min() + if max <= uint8max and min >= 0: + v = v.astype("uint8") + elif max <= uint16max and min >= 0: + v = v.astype("uint16") + elif max <= uint32max and min >= 0: + v = v.astype("uint32") + else: + return v + + dtype = str(v.dtype) + + if dtype in plotlyjsShortTypes: + arrObj = { + "dtype": plotlyjsShortTypes[dtype], + "bdata": base64.b64encode(v).decode("ascii"), + } + + if v.ndim > 1: + arrObj["shape"] = str(v.shape)[1:-1] + + return arrObj + + return v + + +def is_skipped_key(key): + """ + Return whether any keys in the parent hierarchy are in the list of keys that + are skipped for conversion to the typed array spec + """ + skipped_keys = ["geojson", "layer", "range"] + return any(skipped_key in key for skipped_key in skipped_keys) + + +def convert_to_base64(obj): + if isinstance(obj, dict): + for key, value in obj.items(): + if is_skipped_key(key): + continue + elif is_homogeneous_array(value): + obj[key] = to_typed_array_spec(value) + else: + convert_to_base64(value) + elif isinstance(obj, list) or isinstance(obj, tuple): + for i, value in enumerate(obj): + convert_to_base64(value) def cumsum(x): diff --git a/packages/python/plotly/plotly/basedatatypes.py b/packages/python/plotly/plotly/basedatatypes.py index 21b4cb1f312..0fe26c91473 100644 --- a/packages/python/plotly/plotly/basedatatypes.py +++ b/packages/python/plotly/plotly/basedatatypes.py @@ -15,6 +15,7 @@ display_string_positions, chomp_empty_strings, find_closest_string, + convert_to_base64, ) from _plotly_utils.exceptions import PlotlyKeyError from .optional_imports import get_module @@ -3310,6 +3311,9 @@ def to_dict(self): if frames: res["frames"] = frames + # Add base64 conversion before sending to the front-end + convert_to_base64(res) + return res def to_plotly_json(self): diff --git a/packages/python/plotly/plotly/io/_utils.py b/packages/python/plotly/plotly/io/_utils.py index 658540ca71a..6e4fae66b8c 100644 --- a/packages/python/plotly/plotly/io/_utils.py +++ b/packages/python/plotly/plotly/io/_utils.py @@ -24,6 +24,7 @@ def validate_coerce_fig_to_dict(fig, validate): typ=type(fig), v=fig ) ) + return fig_dict