@@ -683,16 +683,19 @@ def check_inputs(
         self,
         prompt,
         image,
+        mask_image,
         height,
         width,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -736,6 +739,19 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
         if isinstance(self.controlnet, MultiControlNetModel):
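
Taken together, these checks make the crop path PIL-only from input to output. A minimal sketch of inputs that pass them; the file names are placeholders:

```python
from PIL import Image

image = Image.open("scene.png").convert("RGB")          # must be a PIL.Image.Image
mask_image = Image.open("scene_mask.png").convert("L")  # white marks the area to repaint
# output_type must remain "pil" (the default); anything else now raises a ValueError.
```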
@@ -862,7 +878,6 @@ def check_image(self, image, prompt, prompt_embeds):
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
 
-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,
@@ -872,10 +887,14 @@ def prepare_control_image(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
@@ -1074,6 +1093,7 @@ def __call__(
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
@@ -1130,6 +1150,12 @@ def __call__(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied
+                to the image and mask_image. If `padding_mask_crop` is not `None`, it first finds a rectangular region
+                with the same aspect ratio as the image that contains all of the masked area, then expands that region
+                by `padding_mask_crop`. The image and mask_image are then cropped to the expanded region and resized
+                to the original image size for inpainting. This is useful when the masked area is small while the
+                image is large and contains information irrelevant to the inpainting, such as background.
             strength (`float`, *optional*, defaults to 1.0):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends
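
Assuming this diff lands in the SD 1.5 ControlNet inpaint pipeline (`StableDiffusionControlNetInpaintPipeline`), a call using the new argument might look like the sketch below. The model IDs, file names, and the `make_inpaint_condition` helper follow the usual diffusers inpaint-ControlNet recipe; they are illustrative, not part of this diff:

```python
import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from PIL import Image


def make_inpaint_condition(image, mask):
    # Standard recipe for the inpaint ControlNet: masked pixels are set to -1.0
    # so the ControlNet knows which region to fill.
    img = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    m = np.array(mask.convert("L")).astype(np.float32) / 255.0
    img[m > 0.5] = -1.0
    return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)


controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = Image.open("room.png").convert("RGB")
mask_image = Image.open("room_mask.png").convert("L")

result = pipe(
    prompt="a modern armchair",
    image=image,
    mask_image=mask_image,
    control_image=make_inpaint_condition(image, mask_image),
    padding_mask_crop=32,  # crop to the masked region plus a 32px margin
).images[0]
```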
@@ -1240,16 +1266,19 @@ def __call__(
         self.check_inputs(
             prompt,
             control_image,
+            mask_image,
             height,
             width,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
             controlnet_conditioning_scale,
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )
 
         self._guidance_scale = guidance_scale
@@ -1264,6 +1293,14 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
         device = self._execution_device
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
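
For orientation, `get_crop_region` roughly computes the padded bounding box of the mask. A simplified sketch of the idea; `crop_region_sketch` is a hypothetical name, and the real `VaeImageProcessor` implementation additionally grows the box to match the image's aspect ratio:

```python
import numpy as np
from PIL import Image


def crop_region_sketch(mask_image: Image.Image, width: int, height: int, pad: int = 0):
    # Bounding box of all nonzero (to-be-repainted) mask pixels, grown by `pad`
    # and clamped to the image bounds. Assumes a non-empty mask.
    mask = np.array(mask_image.convert("L"))
    ys, xs = np.nonzero(mask)
    x1, y1 = max(int(xs.min()) - pad, 0), max(int(ys.min()) - pad, 0)
    x2, y2 = min(int(xs.max()) + pad, width), min(int(ys.max()) + pad, height)
    return x1, y1, x2, y2
```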
@@ -1315,6 +1352,8 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1330,6 +1369,8 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1341,10 +1382,15 @@ def __call__(
             assert False
 
         # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)
 
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
 
         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
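
As a small aside on the convention behind `masked_image = init_image * (mask < 0.5)`: after preprocessing, mask values lie in [0, 1] with 1 marking pixels to repaint, so the multiplication zeroes exactly the region to be inpainted. A tiny self-contained check:

```python
import torch

init_image = torch.ones(1, 3, 4, 4)       # preprocessed image in NCHW layout
mask = torch.zeros(1, 1, 4, 4)
mask[..., 1:3, 1:3] = 1.0                 # repaint the 2x2 center
masked_image = init_image * (mask < 0.5)  # zeroes out the repaint region only
assert masked_image[..., 1:3, 1:3].abs().sum() == 0
```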
@@ -1534,6 +1580,9 @@ def __call__(
 
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()
 
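
The overlay step is what makes the crop round-trip invisible in the output: only the masked pixels of the untouched original are replaced. A loose sketch of the idea; `overlay_sketch` is a hypothetical name, and the real `VaeImageProcessor.apply_overlay` handles the resizing and compositing more carefully:

```python
from PIL import Image


def overlay_sketch(mask_image, original_image, generated_image, crops_coords):
    # Resize the full-frame generated image down into the crop box, then paste
    # it into a copy of the original, blending through the cropped mask so only
    # masked pixels change.
    x1, y1, x2, y2 = crops_coords
    result = original_image.copy()
    patch = generated_image.resize((x2 - x1, y2 - y1))
    blend_mask = mask_image.convert("L").crop((x1, y1, x2, y2))
    result.paste(patch, (x1, y1), blend_mask)
    return result
```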