@@ -854,28 +854,38 @@ def convert_open_clip_checkpoint(checkpoint):
854854 prediction_type = args .prediction_type
855855
856856 checkpoint = torch .load (args .checkpoint_path )
857- global_step = checkpoint ["global_step" ]
857+
858+ # Sometimes models don't have the global_step item
859+ if "global_step" in checkpoint :
860+ global_step = checkpoint ["global_step" ]
861+ else :
862+ print ("global_step key not found in model" )
863+ global_step = None
858864 checkpoint = checkpoint ["state_dict" ]
859865
860866 upcast_attention = False
861867 if args .original_config_file is None :
862868 key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
863869
864870 if key_name in checkpoint and checkpoint [key_name ].shape [- 1 ] == 1024 :
865- # model_type = "v2"
866- os .system (
867- "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
868- )
871+ if not os .path .isfile ("v2-inference-v.yaml" ):
872+ # model_type = "v2"
873+ os .system (
874+ "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
875+ " -O v2-inference-v.yaml"
876+ )
869877 args .original_config_file = "./v2-inference-v.yaml"
870878
871879 if global_step == 110000 :
872880 # v2.1 needs to upcast attention
873881 upcast_attention = True
874882 else :
875- # model_type = "v1"
876- os .system (
877- "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
878- )
883+ if not os .path .isfile ("v1-inference.yaml" ):
884+ # model_type = "v1"
885+ os .system (
886+ "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
887+ " -O v1-inference.yaml"
888+ )
879889 args .original_config_file = "./v1-inference.yaml"
880890
881891 original_config = OmegaConf .load (args .original_config_file )
0 commit comments