2121import PIL .Image
2222import torch
2323import torch .nn .functional as F
24- from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer
24+ from transformers import CLIPImageProcessor , CLIPTextModel , CLIPTokenizer , CLIPVisionModelWithProjection
2525
2626from ...image_processor import PipelineImageInput , VaeImageProcessor
27- from ...loaders import FromSingleFileMixin , LoraLoaderMixin , TextualInversionLoaderMixin
27+ from ...loaders import FromSingleFileMixin , IPAdapterMixin , LoraLoaderMixin , TextualInversionLoaderMixin
2828from ...models import AutoencoderKL , ControlNetModel , UNet2DConditionModel
2929from ...models .lora import adjust_lora_scale_text_encoder
3030from ...schedulers import KarrasDiffusionSchedulers
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
241241
242242
243243class StableDiffusionControlNetInpaintPipeline (
244- DiffusionPipeline , TextualInversionLoaderMixin , LoraLoaderMixin , FromSingleFileMixin
244+ DiffusionPipeline , TextualInversionLoaderMixin , LoraLoaderMixin , IPAdapterMixin , FromSingleFileMixin
245245):
246246 r"""
247247 Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
251251
252252 The pipeline also inherits the following loading methods:
253253 - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
254+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
254255
255256 <Tip>
256257
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
288289 """
289290
290291 model_cpu_offload_seq = "text_encoder->unet->vae"
291- _optional_components = ["safety_checker" , "feature_extractor" ]
292+ _optional_components = ["safety_checker" , "feature_extractor" , "image_encoder" ]
292293 _exclude_from_cpu_offload = ["safety_checker" ]
293294 _callback_tensor_inputs = ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
294295
@@ -302,6 +303,7 @@ def __init__(
302303 scheduler : KarrasDiffusionSchedulers ,
303304 safety_checker : StableDiffusionSafetyChecker ,
304305 feature_extractor : CLIPImageProcessor ,
306+ image_encoder : CLIPVisionModelWithProjection = None ,
305307 requires_safety_checker : bool = True ,
306308 ):
307309 super ().__init__ ()
@@ -334,6 +336,7 @@ def __init__(
334336 scheduler = scheduler ,
335337 safety_checker = safety_checker ,
336338 feature_extractor = feature_extractor ,
339+ image_encoder = image_encoder ,
337340 )
338341 self .vae_scale_factor = 2 ** (len (self .vae .config .block_out_channels ) - 1 )
339342 self .image_processor = VaeImageProcessor (vae_scale_factor = self .vae_scale_factor )
@@ -593,6 +596,20 @@ def encode_prompt(
593596
594597 return prompt_embeds , negative_prompt_embeds
595598
599+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
    r"""
    Encode an IP-Adapter reference image into image embeddings.

    Args:
        image:
            Image input. A `torch.Tensor` is passed to `self.image_encoder` as-is; any other input is first
            converted to pixel values with `self.feature_extractor`.
        device:
            Device the pixel tensor is moved to before it is encoded.
        num_images_per_prompt:
            Number of times each embedding row is repeated along dim 0.

    Returns:
        A tuple `(image_embeds, uncond_image_embeds)`. `uncond_image_embeds` is an all-zeros tensor created
        with `torch.zeros_like(image_embeds)`.

    Raises:
        ValueError: If `self.image_encoder` is `None`, or if `image` is not a tensor and
            `self.feature_extractor` is `None`.
    """
    # `image_encoder` is an optional pipeline component, so fail with an actionable message instead of an
    # opaque AttributeError on `None.parameters()`.
    if self.image_encoder is None:
        raise ValueError(
            "`encode_image` requires an `image_encoder`, but this pipeline was created without one. "
            "Pass an `image_encoder` when constructing the pipeline before using `ip_adapter_image`."
        )

    # Run the encoder in whatever precision its weights are stored in.
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        # `feature_extractor` is optional as well; it is only needed for non-tensor inputs.
        if self.feature_extractor is None:
            raise ValueError(
                "`encode_image` received a non-tensor image but this pipeline has no `feature_extractor` "
                "to preprocess it. Pass a `torch.Tensor` or construct the pipeline with a `feature_extractor`."
            )
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)

    # The unconditional branch uses zeros with the same shape/dtype/device as the conditional embeddings.
    uncond_image_embeds = torch.zeros_like(image_embeds)
    return image_embeds, uncond_image_embeds
612+
596613 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
597614 def run_safety_checker (self , image , device , dtype ):
598615 if self .safety_checker is None :
@@ -1053,6 +1070,7 @@ def __call__(
10531070 latents : Optional [torch .FloatTensor ] = None ,
10541071 prompt_embeds : Optional [torch .FloatTensor ] = None ,
10551072 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
1073+ ip_adapter_image : Optional [PipelineImageInput ] = None ,
10561074 output_type : Optional [str ] = "pil" ,
10571075 return_dict : bool = True ,
10581076 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -1131,6 +1149,7 @@ def __call__(
11311149 negative_prompt_embeds (`torch.FloatTensor`, *optional*):
11321150 Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
11331151 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1152+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
11341153 output_type (`str`, *optional*, defaults to `"pil"`):
11351154 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
11361155 return_dict (`bool`, *optional*, defaults to `True`):
@@ -1264,6 +1283,11 @@ def __call__(
12641283 if self .do_classifier_free_guidance :
12651284 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
12661285
1286+ if ip_adapter_image is not None :
1287+ image_embeds , negative_image_embeds = self .encode_image (ip_adapter_image , device , num_images_per_prompt )
1288+ if self .do_classifier_free_guidance :
1289+ image_embeds = torch .cat ([negative_image_embeds , image_embeds ])
1290+
12671291 # 4. Prepare image
12681292 if isinstance (controlnet , ControlNetModel ):
12691293 control_image = self .prepare_control_image (
@@ -1299,7 +1323,7 @@ def __call__(
12991323 else :
13001324 assert False
13011325
1302- # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
1326+ # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
13031327 init_image = self .image_processor .preprocess (image , height = height , width = width )
13041328 init_image = init_image .to (dtype = torch .float32 )
13051329
@@ -1360,7 +1384,10 @@ def __call__(
13601384 # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
13611385 extra_step_kwargs = self .prepare_extra_step_kwargs (generator , eta )
13621386
1363- # 7.1 Create tensor stating which controlnets to keep
1387+ # 7.1 Add image embeds for IP-Adapter
1388+ added_cond_kwargs = {"image_embeds" : image_embeds } if ip_adapter_image is not None else None
1389+
1390+ # 7.2 Create tensor stating which controlnets to keep
13641391 controlnet_keep = []
13651392 for i in range (len (timesteps )):
13661393 keeps = [
@@ -1423,6 +1450,7 @@ def __call__(
14231450 cross_attention_kwargs = self .cross_attention_kwargs ,
14241451 down_block_additional_residuals = down_block_res_samples ,
14251452 mid_block_additional_residual = mid_block_res_sample ,
1453+ added_cond_kwargs = added_cond_kwargs ,
14261454 return_dict = False ,
14271455 )[0 ]
14281456
0 commit comments