Commit f21415d
Update conversion script to correctly handle SD 2 (huggingface#1511)
* Conversion SD 2
* finish
1 parent 22b9cb0 commit f21415d


scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 115 additions & 35 deletions
@@ -33,6 +33,7 @@
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
 
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
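
For reference, a minimal sketch of what this branch resolves to for SD 1.x versus SD 2 style configs. The two dicts are hand-written stand-ins for the attention-related fields of v1-inference.yaml and v2-inference-v.yaml (an assumption, not output of the script), and the helper name resolve_attention is hypothetical:

sd1_params = {"num_heads": 8}                            # stand-in for v1-inference.yaml
sd2_params = {"use_linear_in_transformer": True,
              "num_head_channels": 64}                   # stand-in for v2-inference-v.yaml

def resolve_attention(unet_params: dict):
    # Mirrors the branch added above: SD 2 configs omit num_heads and set
    # use_linear_in_transformer, so head_dim falls back to [5, 10, 20, 20].
    head_dim = unet_params.get("num_heads")
    use_linear_projection = unet_params.get("use_linear_in_transformer", False)
    if use_linear_projection and head_dim is None:
        head_dim = [5, 10, 20, 20]
    return head_dim, use_linear_projection

print(resolve_attention(sd1_params))  # (8, False)
print(resolve_attention(sd2_params))  # ([5, 10, 20, 20], True)
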
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()

@@ -657,13 +684,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v-prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
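
A quick standalone sketch of how the two flags parse (hypothetical values, only to show that --image_size is an int and --prediction_type a string, and that both default to None for auto-detection later in the script):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--image_size", default=None, type=int)
parser.add_argument("--prediction_type", default=None, type=str)

# Unset flags stay None so the script can resolve them from the checkpoint.
print(parser.parse_args([]))
# Namespace(image_size=None, prediction_type=None)
print(parser.parse_args(["--image_size", "768", "--prediction_type", "v_prediction"]))
# Namespace(image_size=768, prediction_type='v_prediction')
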
@@ -674,73 +710,117 @@ def convert_ldm_clip_checkpoint(checkpoint):
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
-
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"
 
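
The v1/v2 split above keys off the width of a cross-attention projection: SD 2 conditions on 1024-dim OpenCLIP embeddings, SD 1.x on 768-dim CLIP embeddings. A minimal sketch of the same check against a local checkpoint (the path is a placeholder):

import torch

ckpt = torch.load("sd-checkpoint.ckpt", map_location="cpu")  # placeholder path
state_dict = ckpt["state_dict"]

key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
context_dim = state_dict[key_name].shape[-1]

# 1024 -> SD 2 (OpenCLIP ViT-H text encoder), 768 -> SD 1.x (CLIP ViT-L/14)
print("v2" if context_dim == 1024 else "v1", context_dim)
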
     original_config = OmegaConf.load(args.original_config_file)
 
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
 
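
The `parameterization: v` flag is what marks a v-prediction config, which is why it can stand in for "this is an SD 2 checkpoint" before the brittle global-step fallback kicks in. A self-contained sketch of that check; the config dict below is a hand-written stand-in for the fields the script reads from v2-inference-v.yaml, not the real file:

from omegaconf import OmegaConf

# Hand-written stand-in for the relevant fields (assumed values).
original_config = OmegaConf.create({
    "model": {"params": {"parameterization": "v", "timesteps": 1000,
                         "linear_start": 0.00085, "linear_end": 0.0120}}
})

is_v = (
    "parameterization" in original_config["model"]["params"]
    and original_config["model"]["params"]["parameterization"] == "v"
)
print(is_v)  # True for the v-prediction config; the epsilon configs omit the key
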
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
+
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
 
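
The refactor above builds one DDIMScheduler carrying the noise schedule and the resolved prediction_type, then derives the requested scheduler class from its config. A minimal standalone sketch of that pattern (the hard-coded values are only illustrative):

from diffusers import DDIMScheduler, EulerDiscreteScheduler

base = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)

# from_config copies every field the target scheduler understands and ignores
# the rest, so the betas and prediction_type carry across.
euler = EulerDiscreteScheduler.from_config(base.config)
print(euler.config.prediction_type)  # v_prediction
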
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
+
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
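
Once the script has written the converted pipeline to --dump_path, the result is a regular diffusers folder. A hedged usage sketch (the local path and prompt are placeholders; assumes a CUDA GPU and that the SD 2 branch above was taken, so no safety checker is bundled):

import torch
from diffusers import StableDiffusionPipeline

# "./converted-sd2" is a placeholder for whatever was passed as --dump_path.
pipe = StableDiffusionPipeline.from_pretrained("./converted-sd2", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("astronaut.png")
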
