2929 UNet2DModel ,
3030 VQModel ,
3131)
32+
3233from diffusers .utils import floats_tensor , load_image , load_numpy , slow , torch_device
34+ from diffusers .pipelines .stable_diffusion .pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
35+
3336from diffusers .utils .testing_utils import require_torch_gpu
3437from PIL import Image
3538from transformers import CLIPTextConfig , CLIPTextModel , CLIPTokenizer
@@ -506,3 +509,171 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
506509 mem_bytes = torch .cuda .max_memory_allocated ()
507510 # make sure that less than 2.2 GB is allocated
508511 assert mem_bytes < 2.2 * 10 ** 9
512+
513+ class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests (unittest .TestCase ):
514+ def test_pil_inputs (self ):
515+ im = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
516+ im = Image .fromarray (im )
517+ mask = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
518+ mask = Image .fromarray ((mask * 255 ).astype (np .uint8 ))
519+
520+ t_mask , t_masked = prepare_mask_and_masked_image (im , mask )
521+
522+ self .assertTrue (isinstance (t_mask , torch .Tensor ))
523+ self .assertTrue (isinstance (t_masked , torch .Tensor ))
524+
525+ self .assertEqual (t_mask .ndim , 4 )
526+ self .assertEqual (t_masked .ndim , 4 )
527+
528+ self .assertEqual (t_mask .shape , (1 , 1 , 32 , 32 ))
529+ self .assertEqual (t_masked .shape , (1 , 3 , 32 , 32 ))
530+
531+ self .assertTrue (t_mask .dtype == torch .float32 )
532+ self .assertTrue (t_masked .dtype == torch .float32 )
533+
534+ self .assertTrue (t_mask .min () >= 0.0 )
535+ self .assertTrue (t_mask .max () <= 1.0 )
536+ self .assertTrue (t_masked .min () >= - 1.0 )
537+ self .assertTrue (t_masked .min () <= 1.0 )
538+
539+ self .assertTrue (t_mask .sum () > 0.0 )
540+
541+ def test_np_inputs (self ):
542+ im_np = np .random .randint (0 , 255 , (32 , 32 , 3 ), dtype = np .uint8 )
543+ im_pil = Image .fromarray (im_np )
544+ mask_np = np .random .randint (0 , 255 , (32 , 32 ), dtype = np .uint8 ) > 127.5
545+ mask_pil = Image .fromarray ((mask_np * 255 ).astype (np .uint8 ))
546+
547+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
548+ t_mask_pil , t_masked_pil = prepare_mask_and_masked_image (im_pil , mask_pil )
549+
550+ self .assertTrue ((t_mask_np == t_mask_pil ).all ())
551+ self .assertTrue ((t_masked_np == t_masked_pil ).all ())
552+
553+ def test_torch_3D_2D_inputs (self ):
554+ im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
555+ mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
556+ im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
557+ mask_np = mask_tensor .numpy ()
558+
559+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
560+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
561+
562+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
563+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
564+
565+ def test_torch_3D_3D_inputs (self ):
566+ im_tensor = torch .randint (0 , 255 , (3 , 32 , 32 ), dtype = torch .uint8 )
567+ mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
568+ im_np = im_tensor .numpy ().transpose (1 , 2 , 0 )
569+ mask_np = mask_tensor .numpy ()[0 ]
570+
571+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
572+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
573+
574+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
575+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
576+
577+ def test_torch_4D_2D_inputs (self ):
578+ im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
579+ mask_tensor = torch .randint (0 , 255 , (32 , 32 ), dtype = torch .uint8 ) > 127.5
580+ im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
581+ mask_np = mask_tensor .numpy ()
582+
583+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
584+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
585+
586+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
587+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
588+
589+ def test_torch_4D_3D_inputs (self ):
590+ im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
591+ mask_tensor = torch .randint (0 , 255 , (1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
592+ im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
593+ mask_np = mask_tensor .numpy ()[0 ]
594+
595+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
596+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
597+
598+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
599+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
600+
601+ def test_torch_4D_4D_inputs (self ):
602+ im_tensor = torch .randint (0 , 255 , (1 , 3 , 32 , 32 ), dtype = torch .uint8 )
603+ mask_tensor = torch .randint (0 , 255 , (1 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
604+ im_np = im_tensor .numpy ()[0 ].transpose (1 , 2 , 0 )
605+ mask_np = mask_tensor .numpy ()[0 ][0 ]
606+
607+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
608+ t_mask_np , t_masked_np = prepare_mask_and_masked_image (im_np , mask_np )
609+
610+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
611+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
612+
613+ def test_torch_batch_4D_3D (self ):
614+ im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
615+ mask_tensor = torch .randint (0 , 255 , (2 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
616+
617+ im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
618+ mask_nps = [mask .numpy () for mask in mask_tensor ]
619+
620+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
621+ nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
622+ t_mask_np = torch .cat ([n [0 ] for n in nps ])
623+ t_masked_np = torch .cat ([n [1 ] for n in nps ])
624+
625+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
626+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
627+
628+ def test_torch_batch_4D_4D (self ):
629+ im_tensor = torch .randint (0 , 255 , (2 , 3 , 32 , 32 ), dtype = torch .uint8 )
630+ mask_tensor = torch .randint (0 , 255 , (2 , 1 , 32 , 32 ), dtype = torch .uint8 ) > 127.5
631+
632+ im_nps = [im .numpy ().transpose (1 , 2 , 0 ) for im in im_tensor ]
633+ mask_nps = [mask .numpy ()[0 ] for mask in mask_tensor ]
634+
635+ t_mask_tensor , t_masked_tensor = prepare_mask_and_masked_image (im_tensor / 127.5 - 1 , mask_tensor )
636+ nps = [prepare_mask_and_masked_image (i , m ) for i , m in zip (im_nps , mask_nps )]
637+ t_mask_np = torch .cat ([n [0 ] for n in nps ])
638+ t_masked_np = torch .cat ([n [1 ] for n in nps ])
639+
640+ self .assertTrue ((t_mask_tensor == t_mask_np ).all ())
641+ self .assertTrue ((t_masked_tensor == t_masked_np ).all ())
642+
643+ def test_shape_mismatch (self ):
644+ # test height and width
645+ with self .assertRaises (AssertionError ):
646+ prepare_mask_and_masked_image (torch .randn (3 , 32 , 32 ), torch .randn (64 , 64 ))
647+ # test batch dim
648+ with self .assertRaises (AssertionError ):
649+ prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 64 , 64 ))
650+ # test batch dim
651+ with self .assertRaises (AssertionError ):
652+ prepare_mask_and_masked_image (torch .randn (2 , 3 , 32 , 32 ), torch .randn (4 , 1 , 64 , 64 ))
653+
654+ def test_type_mismatch (self ):
655+ # test tensors-only
656+ with self .assertRaises (TypeError ):
657+ prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .rand (3 , 32 , 32 ).numpy ())
658+ # test tensors-only
659+ with self .assertRaises (TypeError ):
660+ prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ).numpy (), torch .rand (3 , 32 , 32 ))
661+
662+ def test_channels_first (self ):
663+ # test channels first for 3D tensors
664+ with self .assertRaises (AssertionError ):
665+ prepare_mask_and_masked_image (torch .rand (32 , 32 , 3 ), torch .rand (3 , 32 , 32 ))
666+
667+ def test_tensor_range (self ):
668+ # test im <= 1
669+ with self .assertRaises (ValueError ):
670+ prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * 2 , torch .rand (32 , 32 ))
671+ # test im >= -1
672+ with self .assertRaises (ValueError ):
673+ prepare_mask_and_masked_image (torch .ones (3 , 32 , 32 ) * (- 2 ), torch .rand (32 , 32 ))
674+ # test mask <= 1
675+ with self .assertRaises (ValueError ):
676+ prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * 2 )
677+ # test mask >= 0
678+ with self .assertRaises (ValueError ):
679+ prepare_mask_and_masked_image (torch .rand (3 , 32 , 32 ), torch .ones (32 , 32 ) * - 1 )
0 commit comments