@@ -962,7 +962,7 @@ def encode_prompt(
962962 prompt = prompt ,
963963 device = device if device is not None else text_encoder .device ,
964964 num_images_per_prompt = num_images_per_prompt ,
965- text_input_ids = text_input_ids_list [i ],
965+ text_input_ids = text_input_ids_list [i ] if text_input_ids_list else None ,
966966 )
967967 clip_prompt_embeds_list .append (prompt_embeds )
968968 clip_pooled_prompt_embeds_list .append (pooled_prompt_embeds )
@@ -976,7 +976,7 @@ def encode_prompt(
976976 max_sequence_length ,
977977 prompt = prompt ,
978978 num_images_per_prompt = num_images_per_prompt ,
979- text_input_ids = text_input_ids_list [: - 1 ],
979+ text_input_ids = text_input_ids_list [- 1 ] if text_input_ids_list else None ,
980980 device = device if device is not None else text_encoders [- 1 ].device ,
981981 )
982982
@@ -1491,6 +1491,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
14911491 ) = accelerator .prepare (
14921492 transformer , text_encoder_one , text_encoder_two , optimizer , train_dataloader , lr_scheduler
14931493 )
1494+ assert text_encoder_one is not None
1495+ assert text_encoder_two is not None
1496+ assert text_encoder_three is not None
14941497 else :
14951498 transformer , optimizer , train_dataloader , lr_scheduler = accelerator .prepare (
14961499 transformer , optimizer , train_dataloader , lr_scheduler
@@ -1598,7 +1601,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
15981601 tokens_three = tokenize_prompt (tokenizer_three , prompts )
15991602 prompt_embeds , pooled_prompt_embeds = encode_prompt (
16001603 text_encoders = [text_encoder_one , text_encoder_two , text_encoder_three ],
1601- tokenizers = [None , None , tokenizer_three ],
1604+ tokenizers = [None , None , None ],
16021605 prompt = prompts ,
16031606 max_sequence_length = args .max_sequence_length ,
16041607 text_input_ids_list = [tokens_one , tokens_two , tokens_three ],
@@ -1608,7 +1611,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16081611 prompt_embeds , pooled_prompt_embeds = encode_prompt (
16091612 text_encoders = [text_encoder_one , text_encoder_two , text_encoder_three ],
16101613 tokenizers = [None , None , tokenizer_three ],
1611- prompt = prompts ,
1614+ prompt = args . instance_prompt ,
16121615 max_sequence_length = args .max_sequence_length ,
16131616 text_input_ids_list = [tokens_one , tokens_two , tokens_three ],
16141617 )
@@ -1685,10 +1688,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
16851688
16861689 accelerator .backward (loss )
16871690 if accelerator .sync_gradients :
1688- params_to_clip = itertools .chain (
1689- transformer_lora_parameters ,
1690- text_lora_parameters_one ,
1691- text_lora_parameters_two if args .train_text_encoder else transformer_lora_parameters ,
1691+ params_to_clip = (
1692+ itertools .chain (
1693+ transformer_lora_parameters , text_lora_parameters_one , text_lora_parameters_two
1694+ )
1695+ if args .train_text_encoder
1696+ else transformer_lora_parameters
16921697 )
16931698 accelerator .clip_grad_norm_ (params_to_clip , args .max_grad_norm )
16941699
@@ -1741,13 +1746,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17411746 text_encoder_one , text_encoder_two , text_encoder_three = load_text_encoders (
17421747 text_encoder_cls_one , text_encoder_cls_two , text_encoder_cls_three
17431748 )
1744- else :
1745- text_encoder_three = text_encoder_cls_three .from_pretrained (
1746- args .pretrained_model_name_or_path ,
1747- subfolder = "text_encoder_3" ,
1748- revision = args .revision ,
1749- variant = args .variant ,
1750- )
17511749 pipeline = StableDiffusion3Pipeline .from_pretrained (
17521750 args .pretrained_model_name_or_path ,
17531751 vae = vae ,
@@ -1767,7 +1765,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17671765 pipeline_args = pipeline_args ,
17681766 epoch = epoch ,
17691767 )
1770- del text_encoder_one , text_encoder_two , text_encoder_three
1768+ if not args .train_text_encoder :
1769+ del text_encoder_one , text_encoder_two , text_encoder_three
1770+
17711771 torch .cuda .empty_cache ()
17721772 gc .collect ()
17731773
0 commit comments