@@ -573,6 +573,13 @@ def parse_args(input_args=None):
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--clip_skip",
+        type=int,
+        default=None,
+        help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+        "the output of the pre-final layer will be used for computing the prompt embeddings.",
+    )
 
     parser.add_argument(
         "--text_encoder_lr",
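For reference, a standalone sketch (not part of the commit) of how the new flag behaves once parsed: omitting it leaves `args.clip_skip` at `None`, which preserves the default penultimate-layer behaviour further down in `encode_prompt`. The parser below mirrors the argument definition added above rather than importing the training script itself.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--clip_skip",
    type=int,
    default=None,
    help="Number of layers to be skipped from CLIP while computing the prompt embeddings.",
)

print(parser.parse_args([]).clip_skip)                    # None -> encode_prompt uses hidden_states[-2]
print(parser.parse_args(["--clip_skip", "1"]).clip_skip)  # 1    -> encode_prompt uses hidden_states[-3]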
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
 
 
 # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
     prompt_embeds_list = []
 
     for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
12531260
12541261 # We are only ALWAYS interested in the pooled output of the final text encoder
12551262 pooled_prompt_embeds = prompt_embeds [0 ]
1256- prompt_embeds = prompt_embeds [- 1 ][- 2 ]
1263+ if clip_skip is None :
1264+ prompt_embeds = prompt_embeds [- 1 ][- 2 ]
1265+ else :
1266+ # "2" because SDXL always indexes from the penultimate layer.
1267+ prompt_embeds = prompt_embeds [- 1 ][- (clip_skip + 2 )]
12571268 bs_embed , seq_len , _ = prompt_embeds .shape
12581269 prompt_embeds = prompt_embeds .view (bs_embed , seq_len , - 1 )
12591270 prompt_embeds_list .append (prompt_embeds )
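A minimal sketch (not part of the patch) of what the indexing above selects, assuming the text encoder was called with output_hidden_states=True so that prompt_embeds[-1] is the tuple of per-layer hidden states; the layer count and tensor shapes below are placeholders.

import torch

# Stand-in for prompt_embeds[-1]: hidden_states[0] is the embedding output,
# hidden_states[-1] the final transformer layer (a 12-layer encoder is assumed here).
num_layers = 12
hidden_states = tuple(torch.full((1, 77, 768), float(i)) for i in range(num_layers + 1))

def select_hidden_state(hidden_states, clip_skip=None):
    # Mirrors the branch added to encode_prompt above.
    if clip_skip is None:
        return hidden_states[-2]              # SDXL default: penultimate layer
    return hidden_states[-(clip_skip + 2)]    # clip_skip=1 -> hidden_states[-3], etc.

assert torch.equal(select_hidden_state(hidden_states), hidden_states[-2])
assert torch.equal(select_hidden_state(hidden_states, clip_skip=1), hidden_states[-3])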
@@ -1830,9 +1841,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
     tokenizers = [tokenizer_one, tokenizer_two]
     text_encoders = [text_encoder_one, text_encoder_two]
 
-    def compute_text_embeddings(prompt, text_encoders, tokenizers):
+    def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
         with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip=clip_skip)
             prompt_embeds = prompt_embeds.to(accelerator.device)
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
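One note on the call above: because encode_prompt takes text_input_ids_list before clip_skip, the value has to be passed by keyword; a fourth positional argument would bind to text_input_ids_list and clip_skip would silently stay None. A tiny stub (not from the commit) showing the difference:

def encode_prompt_stub(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
    # Same parameter order as encode_prompt; only reports what each argument bound to.
    return text_input_ids_list, clip_skip

print(encode_prompt_stub([], [], "a photo", 2))            # (2, None)  clip_skip ignored
print(encode_prompt_stub([], [], "a photo", clip_skip=2))  # (None, 2)  intended behaviour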
@@ -1842,7 +1853,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # the redundant encoding.
     if freeze_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
-            args.instance_prompt, text_encoders, tokenizers
+            args.instance_prompt, text_encoders, tokenizers, args.clip_skip
         )
 
     # Handle class prompt for prior-preservation.
@@ -2052,7 +2063,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
                         prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
-                            prompts, text_encoders, tokenizers
+                            prompts, text_encoders, tokenizers, args.clip_skip
                         )
 
                     else:
@@ -2147,6 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
+                        clip_skip=args.clip_skip,
                     )
                     unet_added_conditions.update(
                         {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}