6060logger = get_logger (__name__ )
6161
6262
63+ def image_grid (imgs , rows , cols ):
64+ assert len (imgs ) == rows * cols
65+
66+ w , h = imgs [0 ].size
67+ grid = Image .new ("RGB" , size = (cols * w , rows * h ))
68+
69+ for i , img in enumerate (imgs ):
70+ grid .paste (img , box = (i % cols * w , i // cols * h ))
71+ return grid
72+
73+
6374def log_validation (vae , text_encoder , tokenizer , unet , controlnet , args , accelerator , weight_dtype , step ):
6475 logger .info ("Running validation... " )
6576
@@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
156167 else :
157168 logger .warn (f"image logging not implemented for { tracker .name } " )
158169
170+ return image_logs
171+
159172
160173def import_model_class_from_model_name_or_path (pretrained_model_name_or_path : str , revision : str ):
161174 text_encoder_config = PretrainedConfig .from_pretrained (
@@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
177190 raise ValueError (f"{ model_class } is not supported." )
178191
179192
193+ def save_model_card (repo_id : str , image_logs = None , base_model = str , repo_folder = None ):
194+ img_str = ""
195+ if image_logs is not None :
196+ img_str = "You can find some example images below.\n "
197+ for i , log in enumerate (image_logs ):
198+ images = log ["images" ]
199+ validation_prompt = log ["validation_prompt" ]
200+ validation_image = log ["validation_image" ]
201+ validation_image .save (os .path .join (repo_folder , "image_control.png" ))
202+ img_str += f"prompt: { validation_prompt } \n "
203+ images = [validation_image ] + images
204+ image_grid (images , 1 , len (images )).save (os .path .join (repo_folder , f"images_{ i } .png" ))
205+ img_str += f"\n "
206+
207+ yaml = f"""
208+ ---
209+ license: creativeml-openrail-m
210+ base_model: { base_model }
211+ tags:
212+ - stable-diffusion
213+ - stable-diffusion-diffusers
214+ - text-to-image
215+ - diffusers
216+ - controlnet
217+ inference: true
218+ ---
219+ """
220+ model_card = f"""
221+ # controlnet-{ repo_id }
222+
223+ These are controlnet weights trained on { base_model } with new type of conditioning.
224+ { img_str }
225+ """
226+ with open (os .path .join (repo_folder , "README.md" ), "w" ) as f :
227+ f .write (yaml + model_card )
228+
229+
180230def parse_args (input_args = None ):
181231 parser = argparse .ArgumentParser (description = "Simple example of a ControlNet training script." )
182232 parser .add_argument (
@@ -943,6 +993,7 @@ def load_model_hook(models, input_dir):
943993 disable = not accelerator .is_local_main_process ,
944994 )
945995
996+ image_logs = None
946997 for epoch in range (first_epoch , args .num_train_epochs ):
947998 for step , batch in enumerate (train_dataloader ):
948999 with accelerator .accumulate (controlnet ):
@@ -1014,7 +1065,7 @@ def load_model_hook(models, input_dir):
10141065 logger .info (f"Saved state to { save_path } " )
10151066
10161067 if args .validation_prompt is not None and global_step % args .validation_steps == 0 :
1017- log_validation (
1068+ image_logs = log_validation (
10181069 vae ,
10191070 text_encoder ,
10201071 tokenizer ,
@@ -1040,6 +1091,12 @@ def load_model_hook(models, input_dir):
10401091 controlnet .save_pretrained (args .output_dir )
10411092
10421093 if args .push_to_hub :
1094+ save_model_card (
1095+ repo_id ,
1096+ image_logs = image_logs ,
1097+ base_model = args .pretrained_model_name_or_path ,
1098+ repo_folder = args .output_dir ,
1099+ )
10431100 upload_folder (
10441101 repo_id = repo_id ,
10451102 folder_path = args .output_dir ,
0 commit comments