@@ -45,16 +45,34 @@ def preprocess_image(image):
4545
4646
4747def preprocess_mask (mask , scale_factor = 8 ):
48- mask = mask .convert ("L" )
49- w , h = mask .size
50- w , h = map (lambda x : x - x % 32 , (w , h )) # resize to integer multiple of 32
51- mask = mask .resize ((w // scale_factor , h // scale_factor ), resample = PIL_INTERPOLATION ["nearest" ])
52- mask = np .array (mask ).astype (np .float32 ) / 255.0
53- mask = np .tile (mask , (4 , 1 , 1 ))
54- mask = mask [None ].transpose (0 , 1 , 2 , 3 ) # what does this step do?
55- mask = 1 - mask # repaint white, keep black
56- mask = torch .from_numpy (mask )
57- return mask
48+
49+ if not isinstance (mask , torch .FloatTensor ):
50+ mask = mask .convert ("L" )
51+ w , h = mask .size
52+ w , h = map (lambda x : x - x % 32 , (w , h )) # resize to integer multiple of 32
53+ mask = mask .resize ((w // scale_factor , h // scale_factor ), resample = PIL_INTERPOLATION ["nearest" ])
54+ mask = np .array (mask ).astype (np .float32 ) / 255.0
55+ mask = np .tile (mask , (4 , 1 , 1 ))
56+ mask = mask [None ].transpose (0 , 1 , 2 , 3 ) # what does this step do?
57+ mask = 1 - mask # repaint white, keep black
58+ mask = torch .from_numpy (mask )
59+ return mask
60+
61+ else :
62+ valid_mask_channel_sizes = [1 , 3 ]
63+ # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W)
64+ if mask .shape [3 ] in valid_mask_channel_sizes :
65+ mask = mask .permute (0 , 3 , 1 , 2 )
66+ elif mask .shape [1 ] not in valid_mask_channel_sizes :
67+ raise ValueError (
68+ f"Mask channel dimension of size in { valid_mask_channel_sizes } should be second or fourth dimension, but received mask of shape { tuple (mask .shape )} "
69+ )
70+ # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
71+ mask = mask .mean (dim = 1 , keepdim = True )
72+ h , w = mask .shape [- 2 :]
73+ h , w = map (lambda x : x - x % 32 , (h , w )) # resize to integer multiple of 32
74+ mask = torch .nn .functional .interpolate (mask , (h // scale_factor , w // scale_factor ))
75+ return mask
5876
5977
6078class StableDiffusionInpaintPipelineLegacy (DiffusionPipeline ):
@@ -497,8 +515,8 @@ def __call__(
497515 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
498516 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
499517 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
500- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
501- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)` .
518+ PIL image, it will be converted to a single channel (luminance) before use. If mask is a tensor, the
519+ expected shape should be either `(B, H, W, C)` or `(B, C, H, W)`, where C is 1 or 3 .
502520 strength (`float`, *optional*, defaults to 0.8):
503521 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
504522 is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -585,8 +603,7 @@ def __call__(
585603 if not isinstance (image , torch .FloatTensor ):
586604 image = preprocess_image (image )
587605
588- if not isinstance (mask_image , torch .FloatTensor ):
589- mask_image = preprocess_mask (mask_image , self .vae_scale_factor )
606+ mask_image = preprocess_mask (mask_image , self .vae_scale_factor )
590607
591608 # 5. set timesteps
592609 self .scheduler .set_timesteps (num_inference_steps , device = device )
@@ -640,6 +657,9 @@ def __call__(
640657 if callback is not None and i % callback_steps == 0 :
641658 callback (i , t , latents )
642659
660+ # use original latents corresponding to unmasked portions of the image
661+ latents = (init_latents_orig * mask ) + (latents * (1 - mask ))
662+
643663 # 10. Post-processing
644664 image = self .decode_latents (latents )
645665
0 commit comments