 import numpy as np
 import PIL.Image
 import torch
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -169,7 +180,11 @@ def retrieve_timesteps(
 
 
 class StableDiffusionXLAdapterPipeline(
-    DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -183,6 +198,7 @@ class StableDiffusionXLAdapterPipeline(
         - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
         - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
     Args:
         adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
@@ -211,8 +227,15 @@ class StableDiffusionXLAdapterPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
-    _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
 
     def __init__(
         self,
@@ -225,6 +248,8 @@ def __init__(
         adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
     ):
         super().__init__()
 
@@ -237,6 +262,8 @@ def __init__(
             unet=unet,
             adapter=adapter,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
        )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
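
Note: since `feature_extractor` and `image_encoder` default to `None` and are listed in `_optional_components`, existing checkpoints without these modules keep loading unchanged, and the two components can also be wired up explicitly at construction time. A minimal sketch of that, assuming the commonly published SDXL IP-Adapter image encoder; the repo and subfolder names below are assumptions, not something this diff pins down:

```python
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter

# Assumed checkpoint locations, based on the usual public releases.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
feature_extractor = CLIPImageProcessor()  # default CLIP preprocessing (224px, CLIP mean/std)

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapter,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
)
```
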
@@ -511,6 +538,31 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
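
Worth spelling out what the two branches of `encode_image` correspond to: with `output_hidden_states` falsy, the encoder's pooled, projected `image_embeds` are used (the plain IP-Adapter / `ImageProjection` path), while the truthy branch takes the penultimate hidden states, which the Plus-style adapters consume. The unconditional input mirrors this: the pooled path uses `torch.zeros_like` on the embeddings, while the hidden-state path encodes an all-zeros image, presumably because a zero tensor is not a meaningful point in hidden-state space. A standalone sketch of the same two encoder calls, using an illustrative CLIP checkpoint:

```python
# Minimal sketch of the two encoder outputs encode_image switches between.
# The checkpoint name is illustrative; any CLIP vision model with a projection head works.
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

pixels = processor(Image.open("ref.png").convert("RGB"), return_tensors="pt").pixel_values

with torch.no_grad():
    # Plain IP-Adapter path: one pooled, projected vector per image.
    pooled = encoder(pixels).image_embeds  # shape (1, projection_dim)

    # Plus-style path: penultimate hidden states keep per-patch detail.
    hidden = encoder(pixels, output_hidden_states=True).hidden_states[-2]  # (1, num_patches + 1, hidden_size)
```
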
@@ -768,7 +820,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -785,6 +837,7 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -876,6 +929,7 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -991,7 +1045,7 @@ def __call__(
 
         device = self._execution_device
 
-        # 3. Encode input prompt
+        # 3.1 Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -1012,6 +1066,15 @@ def __call__(
             clip_skip=clip_skip,
         )
 
+        # 3.2 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
+            image_embeds, negative_image_embeds = self.encode_image(
+                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+            )
+            if self.do_classifier_free_guidance:
+                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
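
One subtlety in the new 3.2 block: the negative image embeddings are concatenated first. That matches the unconditional-first layout the pipeline already uses for the prompt embeddings and the doubled latents, so the later `noise_pred.chunk(2)` split stays correct. A toy illustration of the convention (not pipeline code):

```python
import torch

negative_image_embeds = torch.zeros(2, 4)  # stand-in uncond embeds, batch of 2
image_embeds = torch.ones(2, 4)            # stand-in cond embeds

batched = torch.cat([negative_image_embeds, image_embeds])  # rows 0-1 uncond, rows 2-3 cond

noise_pred_uncond, noise_pred_text = batched.chunk(2)  # same split the loop applies to noise_pred
guidance_scale = 7.5
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```
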
@@ -1028,10 +1091,10 @@ def __call__(
             latents,
         )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 6.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 6.5 Optionally get Guidance Scale Embedding
+        # 6.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1090,8 +1153,7 @@ def __call__(
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
-        # 7.1 Apply denoising_end
+        # Apply denoising_end
         if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
             discrete_timestep_cutoff = int(
                 round(
@@ -1109,9 +1171,12 @@ def __call__(
 
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # predict the noise residual
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
+                # predict the noise residual
                 if i < int(num_inference_steps * adapter_conditioning_factor):
                     down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
                 else:
@@ -1123,9 +1188,9 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )[0]
 
                 # perform guidance
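
For anyone who wants to smoke-test the feature end to end, a hedged sketch of the intended flow; the checkpoint, subfolder, and weight-file names are assumptions based on the commonly published SDXL T2I-Adapter and IP-Adapter releases, not something this diff itself pins down:

```python
import torch
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

adapter = T2IAdapter.from_pretrained(
    "TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# The new capability from this PR: an IP-Adapter image prompt on top of the
# T2I-Adapter structural control. Weight names below are assumptions.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

canny = load_image("canny_edges.png")      # structural control image (T2I-Adapter input)
style = load_image("style_reference.png")  # image prompt (IP-Adapter input)

result = pipe(
    prompt="a cozy cabin in a snowy forest",
    image=canny,
    ip_adapter_image=style,
    num_inference_steps=30,
).images[0]
result.save("out.png")
```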