@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
12561256 key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
12571257 key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
12581258 key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
1259+ config_url = None
12591260
12601261 # model_type = "v1"
1261- config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
1262+ if config_files is not None and "v1" in config_files :
1263+ original_config_file = config_files ["v1" ]
1264+ else :
1265+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
12621266
12631267 if key_name_v2_1 in checkpoint and checkpoint [key_name_v2_1 ].shape [- 1 ] == 1024 :
12641268 # model_type = "v2"
1265- config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
1266-
1269+ if config_files is not None and "v2" in config_files :
1270+ original_config_file = config_files ["v2" ]
1271+ else :
1272+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
12671273 if global_step == 110000 :
12681274 # v2.1 needs to upcast attention
12691275 upcast_attention = True
12701276 elif key_name_sd_xl_base in checkpoint :
12711277 # only base xl has two text embedders
1272- config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
1278+ if config_files is not None and "xl" in config_files :
1279+ original_config_file = config_files ["xl" ]
1280+ else :
1281+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
12731282 elif key_name_sd_xl_refiner in checkpoint :
12741283 # only refiner xl has embedder and one text embedders
1275- config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
1276-
1277- original_config_file = BytesIO (requests .get (config_url ).content )
1284+ if config_files is not None and "xl_refiner" in config_files :
1285+ original_config_file = config_files ["xl_refiner" ]
1286+ else :
1287+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
1288+ if config_url is not None :
1289+ original_config_file = BytesIO (requests .get (config_url ).content )
12781290
12791291 original_config = OmegaConf .load (original_config_file )
12801292
0 commit comments