@@ -314,8 +314,9 @@ def imshow(
314
314
animation_frame = img .dims .index (animation_frame )
315
315
nslices_animation = img .shape [animation_frame ]
316
316
animation_slices = range (nslices_animation )
317
- slice_through = (facet_col is not None ) or (animation_frame is not None )
318
- double_slice_through = (facet_col is not None ) and (animation_frame is not None )
317
+ slice_dimensions = (facet_col is not None ) + (
318
+ animation_frame is not None
319
+ ) # 0, 1, or 2
319
320
facet_label = None
320
321
animation_label = None
321
322
# ----- Define x and y, set labels if img is an xarray -------------------
@@ -344,10 +345,10 @@ def imshow(
344
345
labels ["x" ] = x_label
345
346
if labels .get ("y" , None ) is None :
346
347
labels ["y" ] = y_label
347
- if labels .get ("animation_slice " , None ) is None :
348
- labels ["animation_slice " ] = animation_label
349
- if labels .get ("facet_slice " , None ) is None :
350
- labels ["facet_slice " ] = facet_label
348
+ if labels .get ("animation " , None ) is None :
349
+ labels ["animation " ] = animation_label
350
+ if labels .get ("facet " , None ) is None :
351
+ labels ["facet " ] = facet_label
351
352
if labels .get ("color" , None ) is None :
352
353
labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
353
354
labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -382,32 +383,27 @@ def imshow(
382
383
383
384
# --------------- Starting from here img is always a numpy array --------
384
385
img = np .asanyarray (img )
386
+ # Reshape array so that animation dimension comes first, then facets, then images
385
387
if facet_col is not None :
386
388
img = np .moveaxis (img , facet_col , 0 )
387
- print (img .shape )
388
389
if animation_frame is not None and animation_frame < facet_col :
389
390
animation_frame += 1
390
391
facet_col = True
391
392
if animation_frame is not None :
392
393
img = np .moveaxis (img , animation_frame , 0 )
393
- print (img .shape )
394
394
animation_frame = True
395
- args ["animation_frame" ] = ( # TODO
396
- "slice" if labels .get ("slice " ) is None else labels ["slice " ]
395
+ args ["animation_frame" ] = (
396
+ "slice" if labels .get ("animation " ) is None else labels ["animation " ]
397
397
)
398
398
iterables = ()
399
- if slice_through :
400
- if animation_frame is not None :
401
- iterables += (range (nslices_animation ),)
402
- if facet_col is not None :
403
- iterables += (range (nslices_facet ),)
399
+ if animation_frame is not None :
400
+ iterables += (range (nslices_animation ),)
401
+ if facet_col is not None :
402
+ iterables += (range (nslices_facet ),)
404
403
405
404
# Default behaviour of binary_string: True for RGB images, False for 2D
406
405
if binary_string is None :
407
- if slice_through :
408
- binary_string = img .ndim >= 4 and not is_dataframe
409
- else :
410
- binary_string = img .ndim >= 3 and not is_dataframe
406
+ binary_string = img .ndim >= (3 + slice_dimensions ) and not is_dataframe
411
407
412
408
# Cast bools to uint8 (also one byte)
413
409
if img .dtype == np .bool :
@@ -419,11 +415,7 @@ def imshow(
419
415
420
416
# -------- Contrast rescaling: either minmax or infer ------------------
421
417
if contrast_rescaling is None :
422
- contrast_rescaling = (
423
- "minmax"
424
- if (img .ndim == 2 or (img .ndim == 3 and slice_through ))
425
- else "infer"
426
- )
418
+ contrast_rescaling = "minmax" if img .ndim == (2 + slice_dimensions ) else "infer"
427
419
428
420
# We try to set zmin and zmax only if necessary, because traces have good defaults
429
421
if contrast_rescaling == "minmax" :
@@ -439,19 +431,15 @@ def imshow(
439
431
if zmin is None and zmax is not None :
440
432
zmin = 0
441
433
442
- # For 2d data, use Heatmap trace, unless binary_string is True
443
- if (
444
- img .ndim == 2
445
- or (img .ndim == 3 and slice_through )
446
- or (img .ndim == 4 and double_slice_through )
447
- ) and not binary_string :
448
- y_index = 1 if slice_through else 0
434
+ # For 2d data, use Heatmap trace, unless binary_string is True
435
+ if img .ndim == 2 + slice_dimensions and not binary_string :
436
+ y_index = slice_dimensions
449
437
if y is not None and img .shape [y_index ] != len (y ):
450
438
raise ValueError (
451
439
"The length of the y vector must match the length of the first "
452
440
+ "dimension of the img matrix."
453
441
)
454
- x_index = 2 if slice_through else 1
442
+ x_index = slice_dimensions + 1
455
443
if x is not None and img .shape [x_index ] != len (x ):
456
444
raise ValueError (
457
445
"The length of the x vector must match the length of the second "
@@ -480,7 +468,8 @@ def imshow(
480
468
481
469
# For 2D+RGB data, use Image trace
482
470
elif (
483
- img .ndim >= 3 and (img .shape [- 1 ] in [3 , 4 ] or slice_through and binary_string )
471
+ img .ndim >= 3
472
+ and (img .shape [- 1 ] in [3 , 4 ] or slice_dimensions and binary_string )
484
473
) or (img .ndim == 2 and binary_string ):
485
474
rescale_image = True # to check whether image has been modified
486
475
if zmin is not None and zmax is not None :
@@ -492,11 +481,7 @@ def imshow(
492
481
if zmin is None and zmax is None : # no rescaling, faster
493
482
img_rescaled = img
494
483
rescale_image = False
495
- elif (
496
- img .ndim == 2
497
- or (img .ndim == 3 and slice_through )
498
- or (img .ndim == 4 and double_slice_through )
499
- ):
484
+ elif img .ndim == 2 + slice_dimensions : # single-channel image
500
485
img_rescaled = rescale_intensity (
501
486
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
502
487
)
@@ -547,9 +532,7 @@ def imshow(
547
532
# Now build figure
548
533
col_labels = []
549
534
if facet_col is not None :
550
- slice_label = "slice" if labels .get ("slice" ) is None else labels ["slice" ]
551
- if facet_slices is None :
552
- facet_slices = range (nslices_facet )
535
+ slice_label = "slice" if labels .get ("facet" ) is None else labels ["facet" ]
553
536
col_labels = ["%s = %d" % (slice_label , i ) for i in facet_slices ]
554
537
fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
555
538
layout_patch = dict ()
@@ -566,12 +549,12 @@ def imshow(
566
549
if (facet_col and index < nrows * ncols ) or index == 0 :
567
550
fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
568
551
if animation_frame is not None :
569
- for i in range (nslices_animation ):
552
+ for i , index in zip ( range (nslices_animation ), animation_slices ):
570
553
frame_list .append (
571
554
dict (
572
555
data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
573
556
layout = layout ,
574
- name = str (i ),
557
+ name = str (index ),
575
558
)
576
559
)
577
560
if animation_frame :
@@ -607,5 +590,5 @@ def imshow(
607
590
if labels ["y" ]:
608
591
fig .update_yaxes (title_text = labels ["y" ])
609
592
configure_animation_controls (args , go .Image , fig )
610
- # fig.update_layout(template=args["template"], overwrite=True)
593
+ fig .update_layout (template = args ["template" ], overwrite = True )
611
594
return fig
0 commit comments