@@ -734,8 +734,12 @@ def _copy_layers(hf_layers, pt_layers):
734734 return hf_model
735735
736736
737- def convert_ldm_clip_checkpoint (checkpoint , local_files_only = False ):
738- text_model = CLIPTextModel .from_pretrained ("openai/clip-vit-large-patch14" , local_files_only = local_files_only )
737+ def convert_ldm_clip_checkpoint (checkpoint , local_files_only = False , text_encoder = None ):
738+ text_model = (
739+ CLIPTextModel .from_pretrained ("openai/clip-vit-large-patch14" , local_files_only = local_files_only )
740+ if text_encoder is None
741+ else text_encoder
742+ )
739743
740744 keys = list (checkpoint .keys ())
741745
@@ -1025,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
10251029 load_safety_checker : bool = True ,
10261030 pipeline_class : DiffusionPipeline = None ,
10271031 local_files_only = False ,
1032+ text_encoder = None ,
1033+ tokenizer = None ,
10281034) -> DiffusionPipeline :
10291035 """
10301036 Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1072,6 +1078,15 @@ def download_from_original_stable_diffusion_ckpt(
10721078 The pipeline class to use. Pass `None` to determine automatically.
10731079 local_files_only (`bool`, *optional*, defaults to `False`):
10741080 Whether or not to only look at local files (i.e., do not try to download the model).
1081+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
1082+ An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
1083+ to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
1084+ variant. If this parameter is `None`, the function will load a new instance of CLIP by itself, if needed.
1085+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
1086+ An instance of
1087+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
1088+ to use. If this parameter is `None`, the function will load a new instance of CLIPTokenizer by itself, if
1089+ needed.
10751090 return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
10761091 """
10771092
@@ -1327,8 +1342,10 @@ def download_from_original_stable_diffusion_ckpt(
13271342 feature_extractor = feature_extractor ,
13281343 )
13291344 elif model_type == "FrozenCLIPEmbedder" :
1330- text_model = convert_ldm_clip_checkpoint (checkpoint , local_files_only = local_files_only )
1331- tokenizer = CLIPTokenizer .from_pretrained ("openai/clip-vit-large-patch14" )
1345+ text_model = convert_ldm_clip_checkpoint (
1346+ checkpoint , local_files_only = local_files_only , text_encoder = text_encoder
1347+ )
1348+ tokenizer = CLIPTokenizer .from_pretrained ("openai/clip-vit-large-patch14" ) if tokenizer is None else tokenizer
13321349
13331350 if load_safety_checker :
13341351 safety_checker = StableDiffusionSafetyChecker .from_pretrained ("CompVis/stable-diffusion-safety-checker" )
0 commit comments