    UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
+from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer, CLIPVisionConfig


def shave_segments(path, n_shave_prefix_segments=1):
@@ -647,6 +648,73 @@ def convert_ldm_clip_checkpoint(checkpoint):
    return text_model


+def convert_paint_by_example_checkpoint(checkpoint):
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+    model = PaintByExampleImageEncoder(config)
+
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+
+    for key in keys:
+        if key.startswith("cond_stage_model.transformer"):
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+
+    # load clip vision
+    model.model.load_state_dict(text_model_dict)
+
+    # load mapper
+    keys_mapper = {
+        k[len("cond_stage_model.mapper.res") :]: v
+        for k, v in checkpoint.items()
+        if k.startswith("cond_stage_model.mapper")
+    }
+
+    MAPPING = {
+        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
+        "attn.c_proj": ["attn1.to_out.0"],
+        "ln_1": ["norm1"],
+        "ln_2": ["norm3"],
+        "mlp.c_fc": ["ff.net.0.proj"],
+        "mlp.c_proj": ["ff.net.2"],
+    }
+
+    mapped_weights = {}
+    for key, value in keys_mapper.items():
+        # keys look like "blocks.<i>.<module>.<param>", e.g. "blocks.0.attn.c_qkv.weight"
+        prefix = key[: len("blocks.i")]
+        suffix = key.split(prefix)[-1].split(".")[-1]
+        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
+        mapped_names = MAPPING[name]
+
+        # fused weights (e.g. c_qkv) are split evenly along dim 0 into the mapped q, k, v entries
+        num_splits = len(mapped_names)
+        for i, mapped_name in enumerate(mapped_names):
+            new_name = ".".join([prefix, mapped_name, suffix])
+            shape = value.shape[0] // num_splits
+            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
+
+    model.mapper.load_state_dict(mapped_weights)
+
+    # load final layer norm
+    model.final_layer_norm.load_state_dict(
+        {
+            "bias": checkpoint["cond_stage_model.final_ln.bias"],
+            "weight": checkpoint["cond_stage_model.final_ln.weight"],
+        }
+    )
+
+    # load final proj
+    model.proj_out.load_state_dict(
+        {
+            "bias": checkpoint["proj_out.bias"],
+            "weight": checkpoint["proj_out.weight"],
+        }
+    )
+
+    # load uncond vector
+    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
+    return model
+
+
def convert_open_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

@@ -676,12 +744,24 @@ def convert_open_clip_checkpoint(checkpoint):
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
+    parser.add_argument(
+        "--num_in_channels",
+        default=None,
+        type=int,
+        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
+    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
+    parser.add_argument(
+        "--pipeline_type",
+        default=None,
+        type=str,
+        help="The pipeline type. If `None`, the pipeline will be automatically inferred.",
+    )
    parser.add_argument(
        "--image_size",
        default=None,
@@ -737,6 +817,9 @@ def convert_open_clip_checkpoint(checkpoint):

    original_config = OmegaConf.load(args.original_config_file)

+    if args.num_in_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
+
    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
@@ -806,8 +889,11 @@ def convert_open_clip_checkpoint(checkpoint):
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model.
-    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
-    if text_model_type == "FrozenOpenCLIPEmbedder":
+    model_type = args.pipeline_type
+    if model_type is None:
+        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+
+    if model_type == "FrozenOpenCLIPEmbedder":
        text_model = convert_open_clip_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
        pipe = StableDiffusionPipeline(
@@ -820,7 +906,19 @@ def convert_open_clip_checkpoint(checkpoint):
            feature_extractor=None,
            requires_safety_checker=False,
        )
-    elif text_model_type == "FrozenCLIPEmbedder":
+    elif model_type == "PaintByExample":
+        vision_model = convert_paint_by_example_checkpoint(checkpoint)
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        pipe = PaintByExamplePipeline(
+            vae=vae,
+            image_encoder=vision_model,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=feature_extractor,
+        )
+    elif model_type == "FrozenCLIPEmbedder":
        text_model = convert_ldm_clip_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
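
For reference, a minimal usage sketch (not part of this commit) showing how the new code path could be exercised end to end. All paths are placeholders, --pipeline_type and --num_in_channels are the flags added in this diff, the remaining flags already exist in the script, and 9 input channels is only an assumption for the inpainting-style UNet that Paint-by-Example uses.

# Hypothetical conversion command (placeholder paths):
#
#   python convert_original_stable_diffusion_to_diffusers.py \
#       --checkpoint_path ./paint_by_example.ckpt \
#       --original_config_file ./original_config.yaml \
#       --pipeline_type PaintByExample \
#       --num_in_channels 9 \
#       --dump_path ./paint-by-example-diffusers
#
# Loading the converted pipeline back with the class imported in this diff:
from diffusers.pipelines.paint_by_example import PaintByExamplePipeline

pipe = PaintByExamplePipeline.from_pretrained("./paint-by-example-diffusers")
# The pipeline is then called with an init image, a mask, and an example image,
# e.g. pipe(image=..., mask_image=..., example_image=...).images[0]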