Skip to content

Commit d2dc4de

Browse files
authored
Handle missing global_step key in scripts/convert_original_stable_diffusion_to_diffusers.py (huggingface#1612)
handle missing global_step key and don't download config if it already exists
1 parent ded3299 commit d2dc4de

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -854,28 +854,38 @@ def convert_open_clip_checkpoint(checkpoint):
854854
prediction_type = args.prediction_type
855855

856856
checkpoint = torch.load(args.checkpoint_path)
857-
global_step = checkpoint["global_step"]
857+
858+
# Sometimes models don't have the global_step item
859+
if "global_step" in checkpoint:
860+
global_step = checkpoint["global_step"]
861+
else:
862+
print("global_step key not found in model")
863+
global_step = None
858864
checkpoint = checkpoint["state_dict"]
859865

860866
upcast_attention = False
861867
if args.original_config_file is None:
862868
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
863869

864870
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
865-
# model_type = "v2"
866-
os.system(
867-
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
868-
)
871+
if not os.path.isfile("v2-inference-v.yaml"):
872+
# model_type = "v2"
873+
os.system(
874+
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
875+
" -O v2-inference-v.yaml"
876+
)
869877
args.original_config_file = "./v2-inference-v.yaml"
870878

871879
if global_step == 110000:
872880
# v2.1 needs to upcast attention
873881
upcast_attention = True
874882
else:
875-
# model_type = "v1"
876-
os.system(
877-
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
878-
)
883+
if not os.path.isfile("v1-inference.yaml"):
884+
# model_type = "v1"
885+
os.system(
886+
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
887+
" -O v1-inference.yaml"
888+
)
879889
args.original_config_file = "./v1-inference.yaml"
880890

881891
original_config = OmegaConf.load(args.original_config_file)

0 commit comments

Comments
 (0)