Skip to content

Commit 3bec90f

Browse files
Handle batches and Tensors in pipeline_stable_diffusion_inpaint.py:prepare_mask_and_masked_image (huggingface#1003)
* Handle batches and Tensors in `prepare_mask_and_masked_image` * `blackfy` upgrade `black` * handle mask as `np.array` * add docstring * revert `black` changes with smaller line length * missing ValueError in docstring * raise `TypeError` for image as tensor but not mask * typo in mask shape selection * check for batch dim * fix: wrong indentation * add tests Co-authored-by: Patrick von Platen <[email protected]>
1 parent eb2425b commit 3bec90f

File tree

2 files changed

+258
-10
lines changed

2 files changed

+258
-10
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,93 @@
3535

3636

3737
def prepare_mask_and_masked_image(image, mask):
38-
image = np.array(image.convert("RGB"))
39-
image = image[None].transpose(0, 3, 1, 2)
40-
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
41-
42-
mask = np.array(mask.convert("L"))
43-
mask = mask.astype(np.float32) / 255.0
44-
mask = mask[None, None]
45-
mask[mask < 0.5] = 0
46-
mask[mask >= 0.5] = 1
47-
mask = torch.from_numpy(mask)
38+
"""
39+
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline.
40+
This means that those inputs will be converted to ``torch.Tensor`` with
41+
shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for
42+
the ``image`` and ``1`` for the ``mask``.
43+
44+
The ``image`` will be converted to ``torch.float32`` and normalized to be in
45+
``[-1, 1]``. The ``mask`` will be binarized (``mask > 0.5``) and cast to
46+
``torch.float32`` too.
47+
48+
Args:
49+
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
50+
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array``
51+
or a ``channels x height x width`` ``torch.Tensor`` or a
52+
``batch x channels x height x width`` ``torch.Tensor``.
53+
mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
54+
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or
55+
a ``1 x height x width`` ``torch.Tensor`` or a
56+
``batch x 1 x height x width`` ``torch.Tensor``.
57+
58+
59+
Raises:
60+
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
61+
ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
62+
ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
63+
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
64+
(or the other way around).
65+
66+
Returns:
67+
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
68+
dimensions: ``batch x channels x height x width``.
69+
"""
70+
if isinstance(image, torch.Tensor):
71+
if not isinstance(mask, torch.Tensor):
72+
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
73+
74+
# Batch single image
75+
if image.ndim == 3:
76+
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
77+
image = image.unsqueeze(0)
78+
79+
# Batch and add channel dim for single mask
80+
if mask.ndim == 2:
81+
mask = mask.unsqueeze(0).unsqueeze(0)
82+
83+
# Batch single mask or add channel dim
84+
if mask.ndim == 3:
85+
# Single batched mask, no channel dim or single mask not batched but channel dim
86+
if mask.shape[0] == 1:
87+
mask = mask.unsqueeze(0)
88+
89+
# Batched masks no channel dim
90+
else:
91+
mask = mask.unsqueeze(1)
92+
93+
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
94+
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
95+
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
96+
97+
# Check image is in [-1, 1]
98+
if image.min() < -1 or image.max() > 1:
99+
raise ValueError("Image should be in [-1, 1] range")
100+
101+
# Check mask is in [0, 1]
102+
if mask.min() < 0 or mask.max() > 1:
103+
raise ValueError("Mask should be in [0, 1] range")
104+
105+
# Binarize mask
106+
mask[mask < 0.5] = 0
107+
mask[mask >= 0.5] = 1
108+
109+
# Image as float32
110+
image = image.to(dtype=torch.float32)
111+
elif isinstance(mask, torch.Tensor):
112+
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
113+
else:
114+
if isinstance(image, PIL.Image.Image):
115+
image = np.array(image.convert("RGB"))
116+
image = image[None].transpose(0, 3, 1, 2)
117+
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
118+
if isinstance(mask, PIL.Image.Image):
119+
mask = np.array(mask.convert("L"))
120+
mask = mask.astype(np.float32) / 255.0
121+
mask = mask[None, None]
122+
mask[mask < 0.5] = 0
123+
mask[mask >= 0.5] = 1
124+
mask = torch.from_numpy(mask)
48125

49126
masked_image = image * (mask < 0.5)
50127

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
UNet2DModel,
3030
VQModel,
3131
)
32+
3233
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
34+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
35+
3336
from diffusers.utils.testing_utils import require_torch_gpu
3437
from PIL import Image
3538
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -506,3 +509,171 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
506509
mem_bytes = torch.cuda.max_memory_allocated()
507510
# make sure that less than 2.2 GB is allocated
508511
assert mem_bytes < 2.2 * 10**9
512+
513+
class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase):
    """Unit tests for ``prepare_mask_and_masked_image``.

    The helper accepts PIL images, numpy arrays, and torch tensors of several
    ranks/batchings. These tests check that (a) all accepted input types yield
    equivalent ``(mask, masked_image)`` float32 tensor pairs, and (b) invalid
    combinations (shape mismatch, mixed types, channels-last tensors, values
    out of range) raise the documented errors.
    """

    def test_pil_inputs(self):
        # PIL image + PIL mask -> batched float32 4D tensors in expected ranges.
        im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im = Image.fromarray(im)
        mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask = Image.fromarray((mask * 255).astype(np.uint8))

        t_mask, t_masked = prepare_mask_and_masked_image(im, mask)

        self.assertTrue(isinstance(t_mask, torch.Tensor))
        self.assertTrue(isinstance(t_masked, torch.Tensor))

        self.assertEqual(t_mask.ndim, 4)
        self.assertEqual(t_masked.ndim, 4)

        self.assertEqual(t_mask.shape, (1, 1, 32, 32))
        self.assertEqual(t_masked.shape, (1, 3, 32, 32))

        self.assertTrue(t_mask.dtype == torch.float32)
        self.assertTrue(t_masked.dtype == torch.float32)

        self.assertTrue(t_mask.min() >= 0.0)
        self.assertTrue(t_mask.max() <= 1.0)
        self.assertTrue(t_masked.min() >= -1.0)
        # BUGFIX: the original line asserted ``t_masked.min() <= 1.0``, which
        # duplicates the lower-bound check above; the upper bound of the
        # [-1, 1] image range must be verified on ``max()``.
        self.assertTrue(t_masked.max() <= 1.0)

        self.assertTrue(t_mask.sum() > 0.0)

    def test_np_inputs(self):
        # numpy-array inputs must produce exactly the same tensors as the
        # equivalent PIL inputs.
        im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
        im_pil = Image.fromarray(im_np)
        mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5
        mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))

        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)
        t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil)

        self.assertTrue((t_mask_np == t_mask_pil).all())
        self.assertTrue((t_masked_np == t_masked_pil).all())

    def test_torch_3D_2D_inputs(self):
        # (C, H, W) image tensor + (H, W) mask tensor vs. numpy equivalents.
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_3D_3D_inputs(self):
        # (C, H, W) image tensor + (1, H, W) mask tensor vs. numpy equivalents.
        im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy().transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_2D_inputs(self):
        # (1, C, H, W) image tensor + (H, W) mask tensor vs. numpy equivalents.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_3D_inputs(self):
        # (1, C, H, W) image tensor + (1, H, W) mask tensor vs. numpy equivalents.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_4D_4D_inputs(self):
        # (1, C, H, W) image tensor + (1, 1, H, W) mask tensor vs. numpy equivalents.
        im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5
        im_np = im_tensor.numpy()[0].transpose(1, 2, 0)
        mask_np = mask_tensor.numpy()[0][0]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np)

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_3D(self):
        # Batched (B, C, H, W) image + (B, H, W) mask must equal concatenating
        # the per-sample numpy results.
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy() for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_torch_batch_4D_4D(self):
        # Batched (B, C, H, W) image + (B, 1, H, W) mask must equal
        # concatenating the per-sample numpy results.
        im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8)
        mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5

        im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor]
        mask_nps = [mask.numpy()[0] for mask in mask_tensor]

        t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor)
        nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)]
        t_mask_np = torch.cat([n[0] for n in nps])
        t_masked_np = torch.cat([n[1] for n in nps])

        self.assertTrue((t_mask_tensor == t_mask_np).all())
        self.assertTrue((t_masked_tensor == t_masked_np).all())

    def test_shape_mismatch(self):
        # test height and width
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64))
        # test batch dim
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64))
        # test batch dim
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64))

    def test_type_mismatch(self):
        # tensor image with non-tensor mask must raise.
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy())
        # non-tensor image with tensor mask must raise.
        with self.assertRaises(TypeError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32))

    def test_channels_first(self):
        # 3D image tensors must be channels-first (3, H, W).
        with self.assertRaises(AssertionError):
            prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32))

    def test_tensor_range(self):
        # test im <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32))
        # test im >= -1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32))
        # test mask <= 1
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2)
        # test mask >= 0
        with self.assertRaises(ValueError):
            prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1)

0 commit comments

Comments
 (0)