@@ -220,6 +220,7 @@ def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
         num_images_per_prompt: int = 1,
+        max_sequence_length: int = 256,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -239,7 +240,7 @@ def _get_t5_prompt_embeds(
         text_inputs = self.tokenizer_3(
             prompt,
             padding="max_length",
-            max_length=self.tokenizer_max_length,
+            max_length=max_sequence_length,
             truncation=True,
             add_special_tokens=True,
             return_tensors="pt",
@@ -250,8 +251,8 @@ def _get_t5_prompt_embeds(
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
             logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to "
-                f" {self.tokenizer_max_length} tokens: {removed_text}"
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
             )

         prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
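A minimal standalone sketch of what the changed tokenizer call now does: the new `max_sequence_length` argument replaces the fixed `tokenizer_max_length` as the truncation bound, and the updated warning reports whatever falls beyond it. This is an illustration only, assuming the `transformers` library; "t5-base" stands in for the pipeline's `tokenizer_3` checkpoint.

# Hedged sketch, not part of the diff: how `max_sequence_length` bounds T5 tokenization.
# Assumption: `transformers` and `torch` are installed; "t5-base" is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
max_sequence_length = 256  # mirrors the new keyword argument

long_prompt = "a highly detailed prompt " * 200
text_inputs = tokenizer(
    long_prompt,
    padding="max_length",
    max_length=max_sequence_length,
    truncation=True,
    add_special_tokens=True,
    return_tensors="pt",
)
untruncated_ids = tokenizer(long_prompt, padding="longest", return_tensors="pt").input_ids

print(text_inputs.input_ids.shape)                      # (1, 256): everything past the bound is cut off
print(untruncated_ids.shape[-1] > max_sequence_length)  # True, so the pipeline would log the truncation warning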
@@ -340,6 +341,7 @@ def encode_prompt(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         clip_skip: Optional[int] = None,
+        max_sequence_length: int = 256,
     ):
         r"""

@@ -420,6 +422,7 @@ def encode_prompt(
             t5_prompt_embed = self._get_t5_prompt_embeds(
                 prompt=prompt_3,
                 num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
                 device=device,
             )

@@ -473,7 +476,10 @@ def encode_prompt(
             negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)

             t5_negative_prompt_embed = self._get_t5_prompt_embeds(
-                prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
+                prompt=negative_prompt_3,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
             )

             negative_clip_prompt_embeds = torch.nn.functional.pad(
@@ -502,6 +508,7 @@ def check_inputs(
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -573,6 +580,9 @@ def check_inputs(
573580 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
574581 )
575582
583+ if max_sequence_length is not None and max_sequence_length > 512 :
584+ raise ValueError (f"`max_sequence_length` cannot be greater than 512 but is { max_sequence_length } " )
585+
576586 def get_timesteps (self , num_inference_steps , strength , device ):
577587 # get the original timestep using init_timestep
578588 init_timestep = min (num_inference_steps * strength , num_inference_steps )
@@ -684,6 +694,7 @@ def __call__(
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 256,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -763,6 +774,7 @@ def __call__(
                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

         Examples:

@@ -786,6 +798,7 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
         )

         self._guidance_scale = guidance_scale
@@ -822,6 +835,7 @@ def __call__(
             device=device,
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
         )

         if self.do_classifier_free_guidance:
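Taken together, the new keyword is threaded from `__call__` through `check_inputs` and `encode_prompt` into `_get_t5_prompt_embeds`. A hedged usage sketch follows, assuming this diff lands in the Stable Diffusion 3 img2img pipeline; the model id, input image, and prompt below are illustrative placeholders, not part of the change.

# Usage sketch under the assumptions stated above; values above 512 now raise a
# ValueError in check_inputs, and 256 remains the default.
import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("https://example.com/input.png")  # placeholder image URL

image = pipe(
    prompt="a long, highly detailed prompt that benefits from the larger T5 context window",
    image=init_image,
    strength=0.6,
    max_sequence_length=512,  # new keyword introduced by this diff
).images[0]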