@@ -196,11 +196,13 @@ def enable_model_cpu_offload(self, gpu_id=0):
196196 # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
197197 def encode_prompt (
198198 self ,
199- prompt ,
199+ prompt : str ,
200+ prompt_2 : Optional [str ] = None ,
200201 device : Optional [torch .device ] = None ,
201202 num_images_per_prompt : int = 1 ,
202203 do_classifier_free_guidance : bool = True ,
203- negative_prompt = None ,
204+ negative_prompt : Optional [str ] = None ,
205+ negative_prompt_2 : Optional [str ] = None ,
204206 prompt_embeds : Optional [torch .FloatTensor ] = None ,
205207 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
206208 pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
@@ -211,8 +213,11 @@ def encode_prompt(
211213 Encodes the prompt into text encoder hidden states.
212214
213215 Args:
214- prompt (`str` or `List[str]`, *optional*):
216+ prompt (`str` or `List[str]`, *optional*):
215217 prompt to be encoded
218+ prompt_2 (`str` or `List[str]`, *optional*):
219+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
220+ used in both text-encoders
216221 device: (`torch.device`):
217222 torch device
218223 num_images_per_prompt (`int`):
@@ -223,6 +228,9 @@ def encode_prompt(
223228 The prompt or prompts not to guide the image generation. If not defined, one has to pass
224229 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
225230 less than `1`).
231+ negative_prompt_2 (`str` or `List[str]`, *optional*):
232+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
233+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
226234 prompt_embeds (`torch.FloatTensor`, *optional*):
227235 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
228236 provided, text embeddings will be generated from `prompt` input argument.
@@ -261,9 +269,11 @@ def encode_prompt(
261269 )
262270
263271 if prompt_embeds is None :
272+ prompt_2 = prompt_2 or prompt
265274 # textual inversion: process multi-vector tokens if necessary
265274 prompt_embeds_list = []
266- for tokenizer , text_encoder in zip (tokenizers , text_encoders ):
275+ prompts = [prompt , prompt_2 ]
276+ for prompt , tokenizer , text_encoder in zip (prompts , tokenizers , text_encoders ):
267277 if isinstance (self , TextualInversionLoaderMixin ):
268278 prompt = self .maybe_convert_prompt (prompt , tokenizer )
269279
@@ -274,8 +284,10 @@ def encode_prompt(
274284 truncation = True ,
275285 return_tensors = "pt" ,
276286 )
287+
277288 text_input_ids = text_inputs .input_ids
278289 untruncated_ids = tokenizer (prompt , padding = "longest" , return_tensors = "pt" ).input_ids
290+ untruncated_ids = tokenizer (prompt , padding = "longest" , return_tensors = "pt" ).input_ids
279291
280292 if untruncated_ids .shape [- 1 ] >= text_input_ids .shape [- 1 ] and not torch .equal (
281293 text_input_ids , untruncated_ids
@@ -311,32 +323,33 @@ def encode_prompt(
311323 negative_pooled_prompt_embeds = torch .zeros_like (pooled_prompt_embeds )
312324 elif do_classifier_free_guidance and negative_prompt_embeds is None :
313325 negative_prompt = negative_prompt or ""
326+ negative_prompt_2 = negative_prompt_2 or negative_prompt
327+
314328 uncond_tokens : List [str ]
315329 if prompt is not None and type (prompt ) is not type (negative_prompt ):
316330 raise TypeError (
317331 f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
318332 f" { type (prompt )} ."
319333 )
320334 elif isinstance (negative_prompt , str ):
321- uncond_tokens = [negative_prompt ]
335+ uncond_tokens = [negative_prompt , negative_prompt_2 ]
322336 elif batch_size != len (negative_prompt ):
323337 raise ValueError (
324338 f"`negative_prompt`: { negative_prompt } has batch size { len (negative_prompt )} , but `prompt`:"
325339 f" { prompt } has batch size { batch_size } . Please make sure that passed `negative_prompt` matches"
326340 " the batch size of `prompt`."
327341 )
328342 else :
329- uncond_tokens = negative_prompt
343+ uncond_tokens = [ negative_prompt , negative_prompt_2 ]
330344
331345 negative_prompt_embeds_list = []
332- for tokenizer , text_encoder in zip (tokenizers , text_encoders ):
333- # textual inversion: procecss multi-vector tokens if necessary
346+ for negative_prompt , tokenizer , text_encoder in zip (uncond_tokens , tokenizers , text_encoders ):
334347 if isinstance (self , TextualInversionLoaderMixin ):
335- uncond_tokens = self .maybe_convert_prompt (uncond_tokens , tokenizer )
348+ negative_prompt = self .maybe_convert_prompt (negative_prompt , tokenizer )
336349
337350 max_length = prompt_embeds .shape [1 ]
338351 uncond_input = tokenizer (
339- uncond_tokens ,
352+ negative_prompt ,
340353 padding = "max_length" ,
341354 max_length = max_length ,
342355 truncation = True ,
@@ -401,9 +414,11 @@ def prepare_extra_step_kwargs(self, generator, eta):
401414 def check_inputs (
402415 self ,
403416 prompt ,
417+ prompt_2 ,
404418 image ,
405419 callback_steps ,
406420 negative_prompt = None ,
421+ negative_prompt_2 = None ,
407422 prompt_embeds = None ,
408423 negative_prompt_embeds = None ,
409424 controlnet_conditioning_scale = 1.0 ,
@@ -423,18 +438,30 @@ def check_inputs(
423438 f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
424439 " only forward one of the two."
425440 )
441+ elif prompt_2 is not None and prompt_embeds is not None :
442+ raise ValueError (
443+ f"Cannot forward both `prompt_2`: { prompt_2 } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
444+ " only forward one of the two."
445+ )
426446 elif prompt is None and prompt_embeds is None :
427447 raise ValueError (
428448 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
429449 )
430450 elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
431451 raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
452+ elif prompt_2 is not None and (not isinstance (prompt_2 , str ) and not isinstance (prompt_2 , list )):
453+ raise ValueError (f"`prompt_2` has to be of type `str` or `list` but is { type (prompt_2 )} " )
432454
433455 if negative_prompt is not None and negative_prompt_embeds is not None :
434456 raise ValueError (
435457 f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`:"
436458 f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
437459 )
460+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None :
461+ raise ValueError (
462+ f"Cannot forward both `negative_prompt_2`: { negative_prompt_2 } and `negative_prompt_embeds`:"
463+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
464+ )
438465
439466 if prompt_embeds is not None and negative_prompt_embeds is not None :
440467 if prompt_embeds .shape != negative_prompt_embeds .shape :
@@ -610,6 +637,7 @@ def upcast_vae(self):
610637 def __call__ (
611638 self ,
612639 prompt : Union [str , List [str ]] = None ,
640+ prompt_2 : Optional [Union [str , List [str ]]] = None ,
613641 image : Union [
614642 torch .FloatTensor ,
615643 PIL .Image .Image ,
@@ -623,6 +651,7 @@ def __call__(
623651 num_inference_steps : int = 50 ,
624652 guidance_scale : float = 7.5 ,
625653 negative_prompt : Optional [Union [str , List [str ]]] = None ,
654+ negative_prompt_2 : Optional [Union [str , List [str ]]] = None ,
626655 num_images_per_prompt : Optional [int ] = 1 ,
627656 eta : float = 0.0 ,
628657 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
@@ -649,6 +678,9 @@ def __call__(
649678 prompt (`str` or `List[str]`, *optional*):
650679 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
651680 instead.
681+ prompt_2 (`str` or `List[str]`, *optional*):
682+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
683+ used in both text-encoders
652684 image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
653685 `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
654686 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
@@ -674,6 +706,9 @@ def __call__(
674706 The prompt or prompts not to guide the image generation. If not defined, one has to pass
675707 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
676708 less than `1`).
709+ negative_prompt_2 (`str` or `List[str]`, *optional*):
710+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
711+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
677712 num_images_per_prompt (`int`, *optional*, defaults to 1):
678713 The number of images to generate per prompt.
679714 eta (`float`, *optional*, defaults to 0.0):
@@ -749,9 +784,11 @@ def __call__(
749784 # 1. Check inputs. Raise error if not correct
750785 self .check_inputs (
751786 prompt ,
787+ prompt_2 ,
752788 image ,
753789 callback_steps ,
754790 negative_prompt ,
791+ negative_prompt_2 ,
755792 prompt_embeds ,
756793 negative_prompt_embeds ,
757794 controlnet_conditioning_scale ,
@@ -791,10 +828,12 @@ def __call__(
791828 negative_pooled_prompt_embeds ,
792829 ) = self .encode_prompt (
793830 prompt ,
831+ prompt_2 ,
794832 device ,
795833 num_images_per_prompt ,
796834 do_classifier_free_guidance ,
797835 negative_prompt ,
836+ negative_prompt_2 ,
798837 prompt_embeds = prompt_embeds ,
799838 negative_prompt_embeds = negative_prompt_embeds ,
800839 lora_scale = text_encoder_lora_scale ,
0 commit comments