@@ -401,6 +401,40 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
             return image_embeds, uncond_image_embeds
 
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
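
Aside (not part of the patch): the convention this helper establishes matters for anyone precomputing embeddings. With classifier-free guidance enabled, each tensor it returns holds the negative and positive image embeds concatenated along dim 0, and `__call__` later recovers the two halves with `.chunk(2)`. A minimal standalone sketch of that round trip, with purely illustrative shapes:

import torch

# Stand-ins for the unconditional and conditional IP-Adapter image embeds.
negative = torch.zeros(1, 4, 768)
positive = torch.ones(1, 4, 768)

# Layout produced by prepare_ip_adapter_image_embeds under classifier-free guidance ...
combined = torch.cat([negative, positive])
# ... and the split performed later in __call__ before the embeds reach the unet.
neg_half, pos_half = combined.chunk(2)
assert torch.equal(neg_half, negative) and torch.equal(pos_half, positive)
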
@@ -535,6 +569,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -583,6 +618,9 @@ def __call__(
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -636,13 +674,24 @@ def __call__(
         # `sag_scale = 0` means no self-attention guidance
         do_self_attention_guidance = sag_scale > 0.0
 
-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                do_classifier_free_guidance,
             )
+
             if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+                image_embeds = []
+                negative_image_embeds = []
+                for tmp_image_embeds in ip_adapter_image_embeds:
+                    single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2)
+                    image_embeds.append(single_image_embeds)
+                    negative_image_embeds.append(single_negative_image_embeds)
+            else:
+                image_embeds = ip_adapter_image_embeds
 
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -687,8 +736,18 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 6.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-        added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
+
+        if do_classifier_free_guidance:
+            added_uncond_kwargs = (
+                {"image_embeds": negative_image_embeds}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+                else None
+            )
 
         # 7. Denoising loop
         store_processor = CrossAttnStoreProcessor()
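
For reference, a minimal usage sketch of the new `ip_adapter_image_embeds` argument. This assumes the file being patched is the StableDiffusionSAGPipeline; the model id, IP-Adapter checkpoint, and image path are illustrative placeholders, not taken from this commit:

import torch
from diffusers import StableDiffusionSAGPipeline
from diffusers.utils import load_image

pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

adapter_image = load_image("reference.png")  # placeholder path to any reference image

# Encode the adapter image once for batch_size * num_images_per_prompt = 1.
# With do_classifier_free_guidance=True the returned tensors carry the
# negative/positive pair concatenated along dim 0, matching the default
# guidance_scale > 1 at call time.
image_embeds = pipe.prepare_ip_adapter_image_embeds(adapter_image, None, "cuda", 1, True)

# Reuse the precomputed embeddings across calls instead of re-encoding the image.
image = pipe(
    "a photo of a robot",
    ip_adapter_image_embeds=image_embeds,
    sag_scale=0.75,
).images[0]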