     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
+    HeunDiscreteScheduler,
     LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):
     vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
 
+    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
+    use_linear_projection = (
+        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+    )
+    if use_linear_projection:
+        # stable diffusion 2-base-512 and 2-768
+        if head_dim is None:
+            head_dim = [5, 10, 20, 20]
+
     config = dict(
         sample_size=image_size // vae_scale_factor,
         in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
         block_out_channels=tuple(block_out_channels),
         layers_per_block=unet_params.num_res_blocks,
         cross_attention_dim=unet_params.context_dim,
-        attention_head_dim=unet_params.num_heads,
+        attention_head_dim=head_dim,
+        use_linear_projection=use_linear_projection,
     )
 
     return config
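
For reference (not part of the diff): the new fallback logic above can be exercised in isolation. A minimal sketch, assuming OmegaConf mappings shaped like the LDM YAML's `unet_config.params`; the helper name `resolve_attention_settings` and the field values are made up for illustration.

```python
from omegaconf import OmegaConf

# Toy stand-ins for `unet_params` as read from the LDM YAML (values are illustrative).
v2_like = OmegaConf.create({"use_linear_in_transformer": True})  # SD 2.x-style config
v1_like = OmegaConf.create({"num_heads": 8})                     # SD 1.x-style config


def resolve_attention_settings(unet_params):
    # Mirrors the diff: v1 configs carry `num_heads`, SD 2.x configs instead set
    # `use_linear_in_transformer` and leave the per-block head dims implicit.
    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection and head_dim is None:
        head_dim = [5, 10, 20, 20]  # per-block head dims the diff hard-codes for SD 2.x
    return head_dim, use_linear_projection


print(resolve_attention_settings(v2_like))  # ([5, 10, 20, 20], True)
print(resolve_attention_settings(v1_like))  # (8, False)
```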
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     return text_model
 
 
+def convert_open_clip_checkpoint(checkpoint):
+    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+
+    # SKIP for now - need openclip -> HF conversion script here
+    # keys = list(checkpoint.keys())
+    #
+    # text_model_dict = {}
+    # for key in keys:
+    #     if key.startswith("cond_stage_model.model.transformer"):
+    #         text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
+    #
+    # text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
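
For reference (not part of the diff): `convert_open_clip_checkpoint` is still a stub that loads the already-converted `stabilityai/stable-diffusion-2` text encoder instead of remapping the OpenCLIP weights. A small sketch of inspecting an original checkpoint to see what such a conversion would eventually have to handle; the checkpoint path is a placeholder, and the 1024-width check mirrors the v2 detection added later in this diff.

```python
import torch

# Load an original Stable Diffusion checkpoint (path is a placeholder).
state_dict = torch.load("path/to/checkpoint.ckpt", map_location="cpu")["state_dict"]

# These are the OpenCLIP text-transformer tensors the commented-out code would remap.
open_clip_keys = [k for k in state_dict if k.startswith("cond_stage_model.model.transformer.")]
print(f"{len(open_clip_keys)} OpenCLIP transformer tensors would need converting")

# Same key the script uses further down to tell SD 2.x (OpenCLIP, width 1024)
# apart from SD 1.x (CLIP ViT-L, width 768).
key = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key in state_dict and state_dict[key].shape[-1] == 1024:
    print("looks like an SD 2.x checkpoint")
```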
@@ -657,13 +684,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
     )
     parser.add_argument(
         "--image_size",
-        default=512,
+        default=None,
         type=int,
         help=(
             "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
             " Base. Use 768 for Stable Diffusion v2."
         ),
     )
+    parser.add_argument(
+        "--prediction_type",
+        default=None,
+        type=str,
+        help=(
+            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
+            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
+        ),
+    )
     parser.add_argument(
         "--extract_ema",
         action="store_true",
@@ -674,73 +710,117 @@ def convert_ldm_clip_checkpoint(checkpoint):
         ),
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
-
     args = parser.parse_args()
 
+    image_size = args.image_size
+    prediction_type = args.prediction_type
+
+    checkpoint = torch.load(args.checkpoint_path)
+    global_step = checkpoint["global_step"]
+    checkpoint = checkpoint["state_dict"]
+
     if args.original_config_file is None:
-        os.system(
-            "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
-        )
-        args.original_config_file = "./v1-inference.yaml"
+        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+
+        if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+            # model_type = "v2"
+            os.system(
+                "wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            )
+            args.original_config_file = "./v2-inference-v.yaml"
+        else:
+            # model_type = "v1"
+            os.system(
+                "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+            )
+            args.original_config_file = "./v1-inference.yaml"
 
     original_config = OmegaConf.load(args.original_config_file)
 
-    checkpoint = torch.load(args.checkpoint_path)
-    checkpoint = checkpoint["state_dict"]
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+        if image_size is None:
+            # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+            # as it relies on a brittle global step parameter here
+            image_size = 512 if global_step == 875000 else 768
+    else:
+        if prediction_type is None:
+            prediction_type = "epsilon"
+        if image_size is None:
+            image_size = 512
 
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
+
+    scheduler = DDIMScheduler(
+        beta_end=beta_end,
+        beta_schedule="scaled_linear",
+        beta_start=beta_start,
+        num_train_timesteps=num_train_timesteps,
+        steps_offset=1,
+        clip_sample=False,
+        set_alpha_to_one=False,
+        prediction_type=prediction_type,
+    )
     if args.scheduler_type == "pndm":
-        scheduler = PNDMScheduler(
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            beta_start=beta_start,
-            num_train_timesteps=num_train_timesteps,
-            skip_prk_steps=True,
-        )
+        config = dict(scheduler.config)
+        config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(config)
     elif args.scheduler_type == "lms":
-        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
+    elif args.scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler":
-        scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
+        scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "euler-ancestral":
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "dpm":
-        scheduler = DPMSolverMultistepScheduler(
-            beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
-        )
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
     elif args.scheduler_type == "ddim":
-        scheduler = DDIMScheduler(
-            beta_start=beta_start,
-            beta_end=beta_end,
-            beta_schedule="scaled_linear",
-            clip_sample=False,
-            set_alpha_to_one=False,
-        )
+        scheduler = scheduler
     else:
         raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
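
For reference (not part of the diff): the scheduler refactor above builds one fully specified `DDIMScheduler` and derives every other scheduler from its config via `from_config`, instead of repeating the beta arguments in each branch. A minimal sketch of that pattern, with illustrative beta values standing in for the ones read from the LDM YAML.

```python
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler

# One base scheduler holds the noise schedule and prediction type
# (0.00085 / 0.012 are illustrative stand-ins for the YAML's linear_start / linear_end).
ddim = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)

# Other scheduler types are created from the same config instead of repeating kwargs;
# config keys a scheduler does not know are ignored.
euler = EulerDiscreteScheduler.from_config(ddim.config)

# PNDM needs one extra flag, so the config is copied into a plain dict and extended first.
pndm_config = dict(ddim.config)
pndm_config["skip_prk_steps"] = True
pndm = PNDMScheduler.from_config(pndm_config)
```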
 
     # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet = UNet2DConditionModel(**unet_config)
+
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
     )
 
-    unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
 
     # Convert the VAE model.
-    vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
 
     # Convert the text model.
     text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenCLIPEmbedder":
+    if text_model_type == "FrozenOpenCLIPEmbedder":
+        text_model = convert_open_clip_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        pipe = StableDiffusionPipeline(
+            vae=vae,
+            text_encoder=text_model,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            requires_safety_checker=False,
+        )
+    elif text_model_type == "FrozenCLIPEmbedder":
         text_model = convert_ldm_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
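
For reference (not part of the diff): once the script has written the converted pipeline to `--dump_path`, it can be loaded back with the regular diffusers API. A minimal usage sketch; the output directory, device, and prompt are placeholders.

```python
import torch
from diffusers import StableDiffusionPipeline

# Directory written by --dump_path (placeholder path).
pipe = StableDiffusionPipeline.from_pretrained("path/to/dump_dir", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # assumes a CUDA device is available

image = pipe("an astronaut riding a horse", num_inference_steps=50).images[0]
image.save("astronaut.png")
```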