Skip to content

Commit a28acb5

Browse files
controlnet sd 2.1 checkpoint conversions (huggingface#2593)
* controlnet sd 2.1 checkpoint conversions * remove global_step -> make config file mandatory
1 parent f1ab955 commit a28acb5

File tree

2 files changed

+196
-30
lines changed

2 files changed

+196
-30
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# coding=utf-8
2+
# Copyright 2023 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
""" Conversion script for stable diffusion checkpoints which _only_ contain a controlnet. """
16+
17+
import argparse
18+
19+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
20+
21+
22+
if __name__ == "__main__":
    # CLI for converting an original (LDM-format) ControlNet checkpoint into a
    # diffusers `ControlNetModel` saved at --dump_path.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--image_size",
        default=512,
        type=int,
        help=(
            # NOTE: typo fixed ("Siffusion" -> "Diffusion") in the help text below.
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--upcast_attention",
        action="store_true",
        help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    args = parser.parse_args()

    # Build the diffusers ControlNetModel from the original checkpoint + YAML config.
    controlnet = download_controlnet_from_original_ckpt(
        checkpoint_path=args.checkpoint_path,
        original_config_file=args.original_config_file,
        image_size=args.image_size,
        extract_ema=args.extract_ema,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )

    # Persist in diffusers layout; safetensors serialization is opt-in.
    controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 105 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
954954
return image_normalizer, image_noising_scheduler
955955

956956

957+
def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    """Build a diffusers ``ControlNetModel`` from an original LDM-format checkpoint.

    The UNet diffusers config is derived from the original YAML config (in
    controlnet mode), the original state dict is remapped to diffusers key
    names, and the resulting weights are loaded into a fresh model.
    """
    config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    config["upcast_attention"] = upcast_attention
    # ControlNetModel does not accept a `sample_size` kwarg, so drop it.
    config.pop("sample_size")

    model = ControlNetModel(**config)

    # Remap the original LDM state-dict keys to the diffusers naming scheme.
    state_dict = convert_ldm_unet_checkpoint(
        checkpoint, config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )
    model.load_state_dict(state_dict)

    return model
974+
975+
957976
def download_from_original_stable_diffusion_ckpt(
958977
checkpoint_path: str,
959978
original_config_file: str = None,
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
10421061
print("global_step key not found in model")
10431062
global_step = None
10441063

1045-
if "state_dict" in checkpoint:
1064+
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
1065+
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
1066+
while "state_dict" in checkpoint:
10461067
checkpoint = checkpoint["state_dict"]
10471068

10481069
if original_config_file is None:
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
10841105
if image_size is None:
10851106
image_size = 512
10861107

1108+
if controlnet is None:
1109+
controlnet = "control_stage_config" in original_config.model.params
1110+
1111+
if controlnet:
1112+
controlnet_model = convert_controlnet_checkpoint(
1113+
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
1114+
)
1115+
10871116
num_train_timesteps = original_config.model.params.timesteps
10881117
beta_start = original_config.model.params.linear_start
10891118
beta_end = original_config.model.params.linear_end
@@ -1143,27 +1172,34 @@ def download_from_original_stable_diffusion_ckpt(
11431172
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
11441173
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
11451174

1146-
if controlnet is None:
1147-
controlnet = "control_stage_config" in original_config.model.params
1148-
1149-
if controlnet and model_type != "FrozenCLIPEmbedder":
1150-
raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
1151-
11521175
if model_type == "FrozenOpenCLIPEmbedder":
11531176
text_model = convert_open_clip_checkpoint(checkpoint)
11541177
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
11551178

11561179
if stable_unclip is None:
1157-
pipe = StableDiffusionPipeline(
1158-
vae=vae,
1159-
text_encoder=text_model,
1160-
tokenizer=tokenizer,
1161-
unet=unet,
1162-
scheduler=scheduler,
1163-
safety_checker=None,
1164-
feature_extractor=None,
1165-
requires_safety_checker=False,
1166-
)
1180+
if controlnet:
1181+
pipe = StableDiffusionControlNetPipeline(
1182+
vae=vae,
1183+
text_encoder=text_model,
1184+
tokenizer=tokenizer,
1185+
unet=unet,
1186+
scheduler=scheduler,
1187+
controlnet=controlnet_model,
1188+
safety_checker=None,
1189+
feature_extractor=None,
1190+
requires_safety_checker=False,
1191+
)
1192+
else:
1193+
pipe = StableDiffusionPipeline(
1194+
vae=vae,
1195+
text_encoder=text_model,
1196+
tokenizer=tokenizer,
1197+
unet=unet,
1198+
scheduler=scheduler,
1199+
safety_checker=None,
1200+
feature_extractor=None,
1201+
requires_safety_checker=False,
1202+
)
11671203
else:
11681204
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
11691205
original_config, clip_stats_path=clip_stats_path, device=device
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
12381274
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
12391275

12401276
if controlnet:
1241-
# Convert the ControlNetModel model.
1242-
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
1243-
ctrlnet_config["upcast_attention"] = upcast_attention
1244-
1245-
ctrlnet_config.pop("sample_size")
1246-
1247-
controlnet_model = ControlNetModel(**ctrlnet_config)
1248-
1249-
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
1250-
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
1251-
)
1252-
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
1253-
12541277
pipe = StableDiffusionControlNetPipeline(
12551278
vae=vae,
12561279
text_encoder=text_model,
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
12781301
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
12791302

12801303
return pipe
1304+
1305+
1306+
def download_controlnet_from_original_ckpt(
    checkpoint_path: str,
    original_config_file: str,
    image_size: int = 512,
    extract_ema: bool = False,
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: str = None,
    from_safetensors: bool = False,
) -> ControlNetModel:
    # FIX: the return annotation previously claimed `StableDiffusionPipeline`,
    # but this function builds and returns a `ControlNetModel`.
    """Load an original (LDM-format) ControlNet checkpoint and convert it.

    Args:
        checkpoint_path: Path to the original checkpoint (PyTorch or safetensors).
        original_config_file: YAML config describing the original architecture
            (must contain a `control_stage_config` entry).
        image_size: Image size the model was trained on (512 for SD v1.x / v2 base,
            768 for SD v2).
        extract_ema: Whether to extract EMA weights when both EMA and non-EMA
            weights are present.
        num_in_channels: Override for the UNet input-channel count; inferred when
            `None`.
        upcast_attention: Whether attention should always be upcasted (needed for
            SD 2.1).
        device: Map location for `torch.load`; defaults to cuda when available.
        from_safetensors: Load the checkpoint with safetensors instead of PyTorch.

    Returns:
        The converted `ControlNetModel`.

    Raises:
        ValueError: If omegaconf (or safetensors, when requested) is unavailable,
            or if the config lacks `control_stage_config`.
    """
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        # Collapsed a redundant if/else that duplicated the same `torch.load`
        # call in both branches; behavior is unchanged.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    original_config = OmegaConf.load(original_config_file)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    # A config without a controlnet section cannot describe a controlnet checkpoint.
    if "control_stage_config" not in original_config.model.params:
        raise ValueError("`control_stage_config` not present in original config")

    controlnet_model = convert_controlnet_checkpoint(
        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
    )

    return controlnet_model

0 commit comments

Comments
 (0)