Skip to content

Commit ac5aa1f

Browse files
committed
polished code and added doc example
1 parent 91c066e commit ac5aa1f

File tree

2 files changed

+43
-45
lines changed

2 files changed

+43
-45
lines changed

doc/python/imshow.md

+16-1
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ from skimage import io
415415
from skimage.data import image_fetcher
416416
path = image_fetcher.fetch('data/cells.tif')
417417
data = io.imread(path)
418-
img = data[25:40]
418+
img = data[20:45:2]
419419
fig = px.imshow(img, facet_col=0, binary_string=True, facet_col_wrap=5, height=700)
420420
fig.show()
421421
```
@@ -466,6 +466,21 @@ fig = px.imshow(ds, animation_frame='time', zmin=220, zmax=300, color_continuous
466466
fig.show()
467467
```
468468

469+
### Combining animations and facets
470+
471+
It is possible to view 4-dimensional datasets (for example, 3-D images evolving with time) using a combination of `animation_frame` and `facet_col`.
472+
473+
```python
474+
import plotly.express as px
475+
from skimage import io
476+
from skimage.data import image_fetcher
477+
path = image_fetcher.fetch('data/cells.tif')
478+
data = io.imread(path)
479+
data = data.reshape((15, 4, 256, 256))[5:]
480+
fig = px.imshow(data, animation_frame=0, facet_col=1, binary_string=True)
481+
fig.show()
482+
```
483+
469484
#### Reference
470485

471486
See https://plotly.com/python/reference/image/ for more information and chart attribute options!

packages/python/plotly/plotly/express/_imshow.py

+27-44
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,9 @@ def imshow(
314314
animation_frame = img.dims.index(animation_frame)
315315
nslices_animation = img.shape[animation_frame]
316316
animation_slices = range(nslices_animation)
317-
slice_through = (facet_col is not None) or (animation_frame is not None)
318-
double_slice_through = (facet_col is not None) and (animation_frame is not None)
317+
slice_dimensions = (facet_col is not None) + (
318+
animation_frame is not None
319+
) # 0, 1, or 2
319320
facet_label = None
320321
animation_label = None
321322
# ----- Define x and y, set labels if img is an xarray -------------------
@@ -344,10 +345,10 @@ def imshow(
344345
labels["x"] = x_label
345346
if labels.get("y", None) is None:
346347
labels["y"] = y_label
347-
if labels.get("animation_slice", None) is None:
348-
labels["animation_slice"] = animation_label
349-
if labels.get("facet_slice", None) is None:
350-
labels["facet_slice"] = facet_label
348+
if labels.get("animation", None) is None:
349+
labels["animation"] = animation_label
350+
if labels.get("facet", None) is None:
351+
labels["facet"] = facet_label
351352
if labels.get("color", None) is None:
352353
labels["color"] = xarray.plot.utils.label_from_attrs(img)
353354
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -382,32 +383,27 @@ def imshow(
382383

383384
# --------------- Starting from here img is always a numpy array --------
384385
img = np.asanyarray(img)
386+
# Reshape array so that animation dimension comes first, then facets, then images
385387
if facet_col is not None:
386388
img = np.moveaxis(img, facet_col, 0)
387-
print(img.shape)
388389
if animation_frame is not None and animation_frame < facet_col:
389390
animation_frame += 1
390391
facet_col = True
391392
if animation_frame is not None:
392393
img = np.moveaxis(img, animation_frame, 0)
393-
print(img.shape)
394394
animation_frame = True
395-
args["animation_frame"] = ( # TODO
396-
"slice" if labels.get("slice") is None else labels["slice"]
395+
args["animation_frame"] = (
396+
"slice" if labels.get("animation") is None else labels["animation"]
397397
)
398398
iterables = ()
399-
if slice_through:
400-
if animation_frame is not None:
401-
iterables += (range(nslices_animation),)
402-
if facet_col is not None:
403-
iterables += (range(nslices_facet),)
399+
if animation_frame is not None:
400+
iterables += (range(nslices_animation),)
401+
if facet_col is not None:
402+
iterables += (range(nslices_facet),)
404403

405404
# Default behaviour of binary_string: True for RGB images, False for 2D
406405
if binary_string is None:
407-
if slice_through:
408-
binary_string = img.ndim >= 4 and not is_dataframe
409-
else:
410-
binary_string = img.ndim >= 3 and not is_dataframe
406+
binary_string = img.ndim >= (3 + slice_dimensions) and not is_dataframe
411407

412408
# Cast bools to uint8 (also one byte)
413409
if img.dtype == np.bool:
@@ -419,11 +415,7 @@ def imshow(
419415

420416
# -------- Contrast rescaling: either minmax or infer ------------------
421417
if contrast_rescaling is None:
422-
contrast_rescaling = (
423-
"minmax"
424-
if (img.ndim == 2 or (img.ndim == 3 and slice_through))
425-
else "infer"
426-
)
418+
contrast_rescaling = "minmax" if img.ndim == (2 + slice_dimensions) else "infer"
427419

428420
# We try to set zmin and zmax only if necessary, because traces have good defaults
429421
if contrast_rescaling == "minmax":
@@ -439,19 +431,15 @@ def imshow(
439431
if zmin is None and zmax is not None:
440432
zmin = 0
441433

442-
# For 2d data, use Heatmap trace, unless binary_string is True
443-
if (
444-
img.ndim == 2
445-
or (img.ndim == 3 and slice_through)
446-
or (img.ndim == 4 and double_slice_through)
447-
) and not binary_string:
448-
y_index = 1 if slice_through else 0
434+
# For 2d data, use Heatmap trace, unless binary_string is True
435+
if img.ndim == 2 + slice_dimensions and not binary_string:
436+
y_index = slice_dimensions
449437
if y is not None and img.shape[y_index] != len(y):
450438
raise ValueError(
451439
"The length of the y vector must match the length of the first "
452440
+ "dimension of the img matrix."
453441
)
454-
x_index = 2 if slice_through else 1
442+
x_index = slice_dimensions + 1
455443
if x is not None and img.shape[x_index] != len(x):
456444
raise ValueError(
457445
"The length of the x vector must match the length of the second "
@@ -480,7 +468,8 @@ def imshow(
480468

481469
# For 2D+RGB data, use Image trace
482470
elif (
483-
img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string)
471+
img.ndim >= 3
472+
and (img.shape[-1] in [3, 4] or slice_dimensions and binary_string)
484473
) or (img.ndim == 2 and binary_string):
485474
rescale_image = True # to check whether image has been modified
486475
if zmin is not None and zmax is not None:
@@ -492,11 +481,7 @@ def imshow(
492481
if zmin is None and zmax is None: # no rescaling, faster
493482
img_rescaled = img
494483
rescale_image = False
495-
elif (
496-
img.ndim == 2
497-
or (img.ndim == 3 and slice_through)
498-
or (img.ndim == 4 and double_slice_through)
499-
):
484+
elif img.ndim == 2 + slice_dimensions: # single-channel image
500485
img_rescaled = rescale_intensity(
501486
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
502487
)
@@ -547,9 +532,7 @@ def imshow(
547532
# Now build figure
548533
col_labels = []
549534
if facet_col is not None:
550-
slice_label = "slice" if labels.get("slice") is None else labels["slice"]
551-
if facet_slices is None:
552-
facet_slices = range(nslices_facet)
535+
slice_label = "slice" if labels.get("facet") is None else labels["facet"]
553536
col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices]
554537
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
555538
layout_patch = dict()
@@ -566,12 +549,12 @@ def imshow(
566549
if (facet_col and index < nrows * ncols) or index == 0:
567550
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
568551
if animation_frame is not None:
569-
for i in range(nslices_animation):
552+
for i, index in zip(range(nslices_animation), animation_slices):
570553
frame_list.append(
571554
dict(
572555
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
573556
layout=layout,
574-
name=str(i),
557+
name=str(index),
575558
)
576559
)
577560
if animation_frame:
@@ -607,5 +590,5 @@ def imshow(
607590
if labels["y"]:
608591
fig.update_yaxes(title_text=labels["y"])
609592
configure_animation_controls(args, go.Image, fig)
610-
# fig.update_layout(template=args["template"], overwrite=True)
593+
fig.update_layout(template=args["template"], overwrite=True)
611594
return fig

0 commit comments

Comments
 (0)