 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...loaders import (
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -147,7 +157,7 @@ def retrieve_latents(
 
 
 class StableDiffusionXLControlNetImg2ImgPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
+    DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin
 ):
     r"""
     Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance.
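
With `IPAdapterMixin` added to the class bases, the pipeline picks up `load_ip_adapter`. A minimal sketch of attaching an adapter; the checkpoint names are illustrative, not mandated by this diff:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline

# Illustrative checkpoints; any SDXL base + SDXL ControlNet pair works.
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
# Provided by the new IPAdapterMixin base:
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
```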
@@ -159,6 +169,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
     - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
     - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+    - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         vae ([`AutoencoderKL`]):
@@ -197,10 +208,19 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
             watermark output images. If not defined, it will default to True if the package is installed, otherwise no
             watermarker will be used.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from input images; used as inputs to the `image_encoder`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
@@ -216,6 +236,8 @@ def __init__(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -231,6 +253,8 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
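
Because `feature_extractor` and `image_encoder` default to `None` and are listed under `_optional_components`, existing SDXL ControlNet checkpoints keep loading unchanged; to use IP-Adapter, both must be supplied. A hedged sketch, reusing `controlnet` from above and assuming the `h94/IP-Adapter` repo layout:

```python
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# CLIP vision tower shipped alongside the SDXL IP-Adapter weights (assumed repo layout).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    image_encoder=image_encoder,
    feature_extractor=CLIPImageProcessor(),
    torch_dtype=torch.float16,
)
```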
@@ -515,6 +539,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
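
`encode_image` has two modes: with `output_hidden_states=True` it returns the penultimate hidden states (consumed by IP-Adapter Plus style projections), otherwise the pooled `image_embeds`; the unconditional branch encodes a zero image in the first case and returns zero tensors in the second. A shape-check sketch, assuming a `pipe` built with an image encoder as above:

```python
import PIL.Image

style = PIL.Image.new("RGB", (512, 512))  # placeholder image prompt
embeds, uncond = pipe.encode_image(style, device=pipe.device, num_images_per_prompt=2)
# Pooled mode: one row per requested image; the negative embeds are all zeros.
assert embeds.shape[0] == 2 and uncond.abs().sum() == 0
```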
@@ -1011,6 +1060,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -1109,6 +1159,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1262,7 +1313,7 @@ def __call__(
         )
         guess_mode = guess_mode or global_pool_conditions
 
-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
@@ -1287,6 +1338,15 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
@@ -1449,6 +1509,9 @@ def __call__(
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
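
Putting it together, a hedged end-to-end call exercising the new argument; URLs and parameter values are placeholders:

```python
from diffusers.utils import load_image

init_image = load_image("https://example.com/room.png")        # img2img input
canny_image = load_image("https://example.com/room_canny.png") # ControlNet condition
style_image = load_image("https://example.com/style.png")      # IP-Adapter image prompt

result = pipe(
    prompt="a cozy living room, cinematic lighting",
    image=init_image,
    control_image=canny_image,
    ip_adapter_image=style_image,  # new in this PR
    strength=0.7,
    num_inference_steps=30,
).images[0]
```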