@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
954954 return image_normalizer , image_noising_scheduler
955955
956956
957+ def convert_controlnet_checkpoint (
958+ checkpoint , original_config , checkpoint_path , image_size , upcast_attention , extract_ema
959+ ):
960+ ctrlnet_config = create_unet_diffusers_config (original_config , image_size = image_size , controlnet = True )
961+ ctrlnet_config ["upcast_attention" ] = upcast_attention
962+
963+ ctrlnet_config .pop ("sample_size" )
964+
965+ controlnet_model = ControlNetModel (** ctrlnet_config )
966+
967+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint (
968+ checkpoint , ctrlnet_config , path = checkpoint_path , extract_ema = extract_ema , controlnet = True
969+ )
970+
971+ controlnet_model .load_state_dict (converted_ctrl_checkpoint )
972+
973+ return controlnet_model
974+
975+
957976def download_from_original_stable_diffusion_ckpt (
958977 checkpoint_path : str ,
959978 original_config_file : str = None ,
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
10421061 print ("global_step key not found in model" )
10431062 global_step = None
10441063
1045- if "state_dict" in checkpoint :
1064+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
1065+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
1066+ while "state_dict" in checkpoint :
10461067 checkpoint = checkpoint ["state_dict" ]
10471068
10481069 if original_config_file is None :
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
10841105 if image_size is None :
10851106 image_size = 512
10861107
1108+ if controlnet is None :
1109+ controlnet = "control_stage_config" in original_config .model .params
1110+
1111+ if controlnet :
1112+ controlnet_model = convert_controlnet_checkpoint (
1113+ checkpoint , original_config , checkpoint_path , image_size , upcast_attention , extract_ema
1114+ )
1115+
10871116 num_train_timesteps = original_config .model .params .timesteps
10881117 beta_start = original_config .model .params .linear_start
10891118 beta_end = original_config .model .params .linear_end
@@ -1143,27 +1172,34 @@ def download_from_original_stable_diffusion_ckpt(
11431172 model_type = original_config .model .params .cond_stage_config .target .split ("." )[- 1 ]
11441173 logger .debug (f"no `model_type` given, `model_type` inferred as: { model_type } " )
11451174
1146- if controlnet is None :
1147- controlnet = "control_stage_config" in original_config .model .params
1148-
1149- if controlnet and model_type != "FrozenCLIPEmbedder" :
1150- raise ValueError ("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'" )
1151-
11521175 if model_type == "FrozenOpenCLIPEmbedder" :
11531176 text_model = convert_open_clip_checkpoint (checkpoint )
11541177 tokenizer = CLIPTokenizer .from_pretrained ("stabilityai/stable-diffusion-2" , subfolder = "tokenizer" )
11551178
11561179 if stable_unclip is None :
1157- pipe = StableDiffusionPipeline (
1158- vae = vae ,
1159- text_encoder = text_model ,
1160- tokenizer = tokenizer ,
1161- unet = unet ,
1162- scheduler = scheduler ,
1163- safety_checker = None ,
1164- feature_extractor = None ,
1165- requires_safety_checker = False ,
1166- )
1180+ if controlnet :
1181+ pipe = StableDiffusionControlNetPipeline (
1182+ vae = vae ,
1183+ text_encoder = text_model ,
1184+ tokenizer = tokenizer ,
1185+ unet = unet ,
1186+ scheduler = scheduler ,
1187+ controlnet = controlnet_model ,
1188+ safety_checker = None ,
1189+ feature_extractor = None ,
1190+ requires_safety_checker = False ,
1191+ )
1192+ else :
1193+ pipe = StableDiffusionPipeline (
1194+ vae = vae ,
1195+ text_encoder = text_model ,
1196+ tokenizer = tokenizer ,
1197+ unet = unet ,
1198+ scheduler = scheduler ,
1199+ safety_checker = None ,
1200+ feature_extractor = None ,
1201+ requires_safety_checker = False ,
1202+ )
11671203 else :
11681204 image_normalizer , image_noising_scheduler = stable_unclip_image_noising_components (
11691205 original_config , clip_stats_path = clip_stats_path , device = device
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
12381274 feature_extractor = AutoFeatureExtractor .from_pretrained ("CompVis/stable-diffusion-safety-checker" )
12391275
12401276 if controlnet :
1241- # Convert the ControlNetModel model.
1242- ctrlnet_config = create_unet_diffusers_config (original_config , image_size = image_size , controlnet = True )
1243- ctrlnet_config ["upcast_attention" ] = upcast_attention
1244-
1245- ctrlnet_config .pop ("sample_size" )
1246-
1247- controlnet_model = ControlNetModel (** ctrlnet_config )
1248-
1249- converted_ctrl_checkpoint = convert_ldm_unet_checkpoint (
1250- checkpoint , ctrlnet_config , path = checkpoint_path , extract_ema = extract_ema , controlnet = True
1251- )
1252- controlnet_model .load_state_dict (converted_ctrl_checkpoint )
1253-
12541277 pipe = StableDiffusionControlNetPipeline (
12551278 vae = vae ,
12561279 text_encoder = text_model ,
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
12781301 pipe = LDMTextToImagePipeline (vqvae = vae , bert = text_model , tokenizer = tokenizer , unet = unet , scheduler = scheduler )
12791302
12801303 return pipe
1304+
1305+
1306+ def download_controlnet_from_original_ckpt (
1307+ checkpoint_path : str ,
1308+ original_config_file : str ,
1309+ image_size : int = 512 ,
1310+ extract_ema : bool = False ,
1311+ num_in_channels : Optional [int ] = None ,
1312+ upcast_attention : Optional [bool ] = None ,
1313+ device : str = None ,
1314+ from_safetensors : bool = False ,
1315+ ) -> StableDiffusionPipeline :
1316+ if not is_omegaconf_available ():
1317+ raise ValueError (BACKENDS_MAPPING ["omegaconf" ][1 ])
1318+
1319+ from omegaconf import OmegaConf
1320+
1321+ if from_safetensors :
1322+ if not is_safetensors_available ():
1323+ raise ValueError (BACKENDS_MAPPING ["safetensors" ][1 ])
1324+
1325+ from safetensors import safe_open
1326+
1327+ checkpoint = {}
1328+ with safe_open (checkpoint_path , framework = "pt" , device = "cpu" ) as f :
1329+ for key in f .keys ():
1330+ checkpoint [key ] = f .get_tensor (key )
1331+ else :
1332+ if device is None :
1333+ device = "cuda" if torch .cuda .is_available () else "cpu"
1334+ checkpoint = torch .load (checkpoint_path , map_location = device )
1335+ else :
1336+ checkpoint = torch .load (checkpoint_path , map_location = device )
1337+
1338+ # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
1339+ # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
1340+ while "state_dict" in checkpoint :
1341+ checkpoint = checkpoint ["state_dict" ]
1342+
1343+ original_config = OmegaConf .load (original_config_file )
1344+
1345+ if num_in_channels is not None :
1346+ original_config ["model" ]["params" ]["unet_config" ]["params" ]["in_channels" ] = num_in_channels
1347+
1348+ if "control_stage_config" not in original_config .model .params :
1349+ raise ValueError ("`control_stage_config` not present in original config" )
1350+
1351+ controlnet_model = convert_controlnet_checkpoint (
1352+ checkpoint , original_config , checkpoint_path , image_size , upcast_attention , extract_ema
1353+ )
1354+
1355+ return controlnet_model
0 commit comments