Skip to content

Commit fb38bb1

Browse files
authored
Support grayscale images in numpy_to_pil (huggingface#1025)
1 parent de00c63 commit fb38bb1

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,11 @@ def numpy_to_pil(images):
444444
if images.ndim == 3:
445445
images = images[None, ...]
446446
images = (images * 255).round().astype("uint8")
447-
pil_images = [Image.fromarray(image) for image in images]
447+
if images.shape[-1] == 1:
448+
# special case for grayscale (single channel) images
449+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
450+
else:
451+
pil_images = [Image.fromarray(image) for image in images]
448452

449453
return pil_images
450454

src/diffusers/pipeline_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,11 @@ def numpy_to_pil(images):
625625
if images.ndim == 3:
626626
images = images[None, ...]
627627
images = (images * 255).round().astype("uint8")
628-
pil_images = [Image.fromarray(image) for image in images]
628+
if images.shape[-1] == 1:
629+
# special case for grayscale (single channel) images
630+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
631+
else:
632+
pil_images = [Image.fromarray(image) for image in images]
629633

630634
return pil_images
631635

0 commit comments

Comments
 (0)