     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
+from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 logger = get_logger(__name__)
 
 
+# TODO: This function should be removed once training scripts are rewritten in PEFT
+def text_encoder_lora_state_dict(text_encoder):
+    state_dict = {}
+
+    def text_encoder_attn_modules(text_encoder):
+        from transformers import CLIPTextModel, CLIPTextModelWithProjection
+
+        attn_modules = []
+
+        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+                name = f"text_model.encoder.layers.{i}.self_attn"
+                mod = layer.self_attn
+                attn_modules.append((name, mod))
+
+        return attn_modules
+
+    for name, module in text_encoder_attn_modules(text_encoder):
+        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+    return state_dict
+
+
 def save_model_card(
     repo_id: str,
     images=None,
     base_model=str,
     train_text_encoder=False,
+    train_text_encoder_ti=False,
+    token_abstraction_dict=None,
     instance_prompt=str,
     validation_prompt=str,
     repo_folder=None,
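
Note on the first hunk above: `text_encoder_lora_state_dict` used to be imported from `diffusers.models.lora` and is now inlined in the script (with a TODO to drop it once training moves to PEFT). The sketch below is not part of the commit; it only illustrates the key layout the helper emits, assuming the text encoder's attention projections have been given `lora_linear_layer` attributes the way the training script does before saving (the model id and rank here are placeholders):

```python
import torch
from transformers import CLIPTextModel
from diffusers.models.lora import LoRALinearLayer

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Hypothetical setup: attach a rank-4 LoRA layer to every attention projection so that
# text_encoder_lora_state_dict() (defined above) has something to collect.
for layer in text_encoder.text_model.encoder.layers:
    attn = layer.self_attn
    for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
        proj = getattr(attn, proj_name)
        proj.lora_linear_layer = LoRALinearLayer(proj.in_features, proj.out_features, rank=4)

lora_state = text_encoder_lora_state_dict(text_encoder)
print(next(iter(lora_state)))
# text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight
```
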
@@ -83,10 +118,23 @@ def save_model_card(
         img_str += f"""
         - text: '{validation_prompt if validation_prompt else ' ' }'
           output:
-            url: >-
+            url:
                 "image_{i}.png"
         """
 
+    trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+    if train_text_encoder_ti:
+        trigger_str = (
+            "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+            "in you prompt with the new inserted tokens:\n"
+        )
+        if token_abstraction_dict:
+            for key, value in token_abstraction_dict.items():
+                tokens = "".join(value)
+                trigger_str += f"""
+to trigger concept {key} -> use {tokens} in your prompt \n
+"""
+
     yaml = f"""
 ---
 tags:
@@ -96,9 +144,7 @@ def save_model_card(
 - diffusers
 - lora
 - template:sd-lora
-widget:
 {img_str}
----
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
@@ -112,14 +158,19 @@ def save_model_card(
 
 ## Model description
 
-These are {repo_id} LoRA adaption weights for {base_model}.
+### These are {repo_id} LoRA adaption weights for {base_model}.
+
 The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
 LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
 Special VAE used for training: {vae_path}.
 
 ## Trigger words
 
-You should use {instance_prompt} to trigger the image generation.
+{trigger_str}
 
 ## Download model
 
@@ -244,6 +295,7 @@ def parse_args(input_args=None):
 
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
+        type=int,
         default=2,
         help="number of new tokens inserted to the tokenizers per token_abstraction value when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
@@ -455,7 +507,7 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--train_text_encoder_frac",
         type=float,
-        default=0.5,
+        default=1.0,
         help=("The percentage of epochs to perform text encoder tuning"),
     )
 
@@ -488,7 +540,7 @@ def parse_args(input_args=None):
     parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+        "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
     )
 
     parser.add_argument(
@@ -679,12 +731,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
     def save_embeddings(self, file_path: str):
         assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
         tensors = {}
+        # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
+        idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
         for idx, text_encoder in enumerate(self.text_encoders):
             assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
                 self.tokenizers[0]
             ), "Tokenizers should be the same."
             new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
-            tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+            # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
+            # text_encoder 1) to keep compatible with the ecosystem.
+            # Note: When loading with diffusers, any name can work - simply specify in inference
+            tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+            # tensors[f"text_encoders_{idx}"] = new_token_embeddings
 
         save_file(tensors, file_path)
 
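
Because the in-script loading helpers are removed further down, the `clip_l` / `clip_g` tensors written by `save_embeddings` are meant to be loaded at inference through diffusers' textual-inversion loader. A hedged sketch, assuming the script's default inserted tokens `<s0>`, `<s1>` and a hypothetical `embeddings.safetensors` path:

```python
import torch
from diffusers import StableDiffusionXLPipeline
from safetensors.torch import load_file

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# hypothetical path to the file written by save_embeddings()
state_dict = load_file("output_dir/embeddings.safetensors")

# "clip_l" -> first text encoder, "clip_g" -> second text encoder (see idx_to_text_encoder_name above)
pipe.load_textual_inversion(
    state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer
)
pipe.load_textual_inversion(
    state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2
)

image = pipe("a photo of <s0><s1>").images[0]
```
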
@@ -696,19 +755,6 @@ def dtype(self):
     def device(self):
         return self.text_encoders[0].device
 
-    # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
-    #     # Assuming new tokens are of the format <s_i>
-    #     self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
-    #     special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
-    #     tokenizer.add_special_tokens(special_tokens_dict)
-    #     text_encoder.resize_token_embeddings(len(tokenizer))
-    #
-    #     self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
-    #     assert self.train_ids is not None, "New tokens could not be converted to IDs."
-    #     text_encoder.text_model.embeddings.token_embedding.weight.data[
-    #         self.train_ids
-    #     ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
-
     @torch.no_grad()
     def retract_embeddings(self):
         for idx, text_encoder in enumerate(self.text_encoders):
@@ -730,15 +776,6 @@ def retract_embeddings(self):
             new_embeddings = new_embeddings * (off_ratio**0.1)
             text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
 
-    # def load_embeddings(self, file_path: str):
-    #     with safe_open(file_path, framework="pt", device=self.device.type) as f:
-    #         for idx in range(len(self.text_encoders)):
-    #             text_encoder = self.text_encoders[idx]
-    #             tokenizer = self.tokenizers[idx]
-    #
-    #             loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
-    #             self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
-
 
 class DreamBoothDataset(Dataset):
     """
@@ -1216,13 +1253,17 @@ def main(args):
         text_lora_parameters_one = []
         for name, param in text_encoder_one.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_one.append(param)
             else:
                 param.requires_grad = False
         text_lora_parameters_two = []
         for name, param in text_encoder_two.named_parameters():
             if "token_embedding" in name:
+                # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+                param = param.to(dtype=torch.float32)
                 param.requires_grad = True
                 text_lora_parameters_two.append(param)
             else:
@@ -1309,12 +1350,16 @@ def load_model_hook(models, input_dir):
         # different learning rate for text encoder and unet
         text_lora_parameters_one_with_lr = {
             "params": text_lora_parameters_one,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         text_lora_parameters_two_with_lr = {
             "params": text_lora_parameters_two,
-            "weight_decay": args.adam_weight_decay_text_encoder,
+            "weight_decay": args.adam_weight_decay_text_encoder
+            if args.adam_weight_decay_text_encoder
+            else args.adam_weight_decay,
             "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
         }
         params_to_optimize = [
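
For context, a compact standalone sketch (hyperparameter values are made up) of the parameter-group layout these dictionaries feed into the optimizer: group 0 is the unet, groups 1 and 2 are the two text encoders, and the text-encoder weight decay now falls back to the unet value when `--adam_weight_decay_text_encoder` is left at its new `None` default:

```python
import torch

# placeholder parameters standing in for the LoRA / token-embedding params of the script
unet_lora_parameters = [torch.nn.Parameter(torch.randn(4, 4))]
text_lora_parameters_one = [torch.nn.Parameter(torch.randn(4, 4))]
text_lora_parameters_two = [torch.nn.Parameter(torch.randn(4, 4))]

learning_rate, text_encoder_lr = 1e-4, None  # hypothetical args values
adam_weight_decay, adam_weight_decay_text_encoder = 1e-4, None

te_weight_decay = adam_weight_decay_text_encoder if adam_weight_decay_text_encoder else adam_weight_decay
te_lr = text_encoder_lr if text_encoder_lr else learning_rate

params_to_optimize = [
    {"params": unet_lora_parameters, "lr": learning_rate, "weight_decay": adam_weight_decay},
    {"params": text_lora_parameters_one, "lr": te_lr, "weight_decay": te_weight_decay},
    {"params": text_lora_parameters_two, "lr": te_lr, "weight_decay": te_weight_decay},
]
optimizer = torch.optim.AdamW(params_to_optimize)
assert optimizer.param_groups[1]["weight_decay"] == adam_weight_decay
```
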
@@ -1494,6 +1539,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
             tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
 
+    if args.train_text_encoder_ti and args.validation_prompt:
+        # replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
+        for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
+            args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+        print("validation prompt:", args.validation_prompt)
+
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
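
Illustrative only (the dictionary and prompt below are hypothetical): with pivotal tuning, each `--token_abstraction` placeholder in the validation prompt is replaced by the concatenation of its newly inserted tokens, exactly as the loop above does:

```python
token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]}  # hypothetical mapping built by the dataset
validation_prompt = "a photo of TOK wearing a hat"

for token_abs, token_replacement in token_abstraction_dict.items():
    validation_prompt = validation_prompt.replace(token_abs, "".join(token_replacement))

print(validation_prompt)  # -> a photo of <s0><s1> wearing a hat
```
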
@@ -1593,27 +1644,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             if epoch == num_train_epochs_text_encoder:
                 print("PIVOT HALFWAY", epoch)
                 # stopping optimization of text_encoder params
-                params_to_optimize = params_to_optimize[:1]
-                # reinitializing the optimizer to optimize only on unet params
-                if args.optimizer.lower() == "prodigy":
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        lr=args.learning_rate,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        beta3=args.prodigy_beta3,
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                        decouple=args.prodigy_decouple,
-                        use_bias_correction=args.prodigy_use_bias_correction,
-                        safeguard_warmup=args.prodigy_safeguard_warmup,
-                    )
-                else:  # AdamW or 8-bit-AdamW
-                    optimizer = optimizer_class(
-                        params_to_optimize,
-                        betas=(args.adam_beta1, args.adam_beta2),
-                        weight_decay=args.adam_weight_decay,
-                        eps=args.adam_epsilon,
-                    )
+                # re setting the optimizer to optimize only on unet params
+                optimizer.param_groups[1]["lr"] = 0.0
+                optimizer.param_groups[2]["lr"] = 0.0
+
             else:
                 # still optimizng the text encoder
                 text_encoder_one.train()
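
The replacement above trades re-building the optimizer for a cheaper "pivot": once the text-encoder fraction of epochs is done, the learning rate of parameter groups 1 and 2 (the text-encoder groups in the layout sketched earlier) is set to 0.0, so only the unet keeps updating and the existing optimizer state survives. A toy sketch, not from the commit, of why a zero learning rate freezes those weights under AdamW:

```python
import torch

unet_params = [torch.nn.Parameter(torch.randn(2, 2))]
text_encoder_params = [torch.nn.Parameter(torch.randn(2, 2))]
optimizer = torch.optim.AdamW(
    [{"params": unet_params, "lr": 1e-4}, {"params": text_encoder_params, "lr": 5e-5}]
)

# ... train both groups for the text-encoder epochs, then "pivot":
optimizer.param_groups[1]["lr"] = 0.0

before = text_encoder_params[0].detach().clone()
loss = unet_params[0].sum() + text_encoder_params[0].sum()
loss.backward()
optimizer.step()

# both the decoupled weight decay and the Adam update are scaled by lr, so lr = 0 means no change
assert torch.equal(text_encoder_params[0].detach(), before)
```
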
@@ -1628,7 +1662,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                 prompts = batch["prompts"]
-                print(prompts)
+                # print(prompts)
                 # encode batch prompts when custom prompts are provided for each image -
                 if train_dataset.custom_instance_prompts:
                     if freeze_text_encoder:
@@ -1801,7 +1835,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                     f" {args.validation_prompt}."
                 )
                 # create pipeline
-                if not args.train_text_encoder:
+                if freeze_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
                         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
                     )
@@ -1948,6 +1982,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
+            train_text_encoder_ti=args.train_text_encoder_ti,
+            token_abstraction_dict=train_dataset.token_abstraction_dict,
             instance_prompt=args.instance_prompt,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,