
Commit 1500130

Support for manual CLIP loading in StableDiffusionPipeline - txt2img. (huggingface#3832)

* Support for manual CLIP loading in StableDiffusionPipeline - txt2img.
* Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
* Update variables & corresponding docs to match the previous style.
* Updated to match the style & quality of 'diffusers'

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 219636f commit 1500130

2 files changed: +36 additions, -4 deletions

src/diffusers/loaders.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -1339,6 +1339,17 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
                 "ddim"]`.
             load_safety_checker (`bool`, *optional*, defaults to `True`):
                 Whether to load the safety checker or not.
+            text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+                An instance of
+                [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
+                specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+                variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
+                needed.
+            tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+                An instance of
+                [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+                to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
+                itself, if needed.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -1383,6 +1394,8 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         upcast_attention = kwargs.pop("upcast_attention", None)
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        tokenizer = kwargs.pop("tokenizer", None)

         torch_dtype = kwargs.pop("torch_dtype", None)

@@ -1463,6 +1476,8 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
             upcast_attention=upcast_attention,
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
         )

         if torch_dtype is not None:
```
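Taken together, these changes let a caller hand `from_ckpt` pre-loaded CLIP components instead of having the converter fetch them from the Hub. A minimal usage sketch, assuming the checkpoint filename and the local CLIP directory below are placeholders (they are not part of this commit):

```python
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer

# Manually load the CLIP text encoder and tokenizer once, e.g. from a local
# copy of the clip-vit-large-patch14 variant the txt2img pipeline expects.
# "./clip-vit-large-patch14" is a placeholder local path.
text_encoder = CLIPTextModel.from_pretrained("./clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("./clip-vit-large-patch14")

# "model.ckpt" is a placeholder CompVis-style checkpoint path. Because both
# components are supplied, from_ckpt forwards them to the conversion code
# instead of downloading openai/clip-vit-large-patch14 itself.
pipe = StableDiffusionPipeline.from_ckpt(
    "model.ckpt",
    text_encoder=text_encoder,
    tokenizer=tokenizer,
)
```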

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 21 additions & 4 deletions

```diff
@@ -734,8 +734,12 @@ def _copy_layers(hf_layers, pt_layers):
     return hf_model


-def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+    text_model = (
+        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+        if text_encoder is None
+        else text_encoder
+    )

     keys = list(checkpoint.keys())

@@ -1025,6 +1029,8 @@ def download_from_original_stable_diffusion_ckpt(
     load_safety_checker: bool = True,
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
+    text_encoder=None,
+    tokenizer=None,
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1072,6 +1078,15 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
+            An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
+            to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
+            variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
+        tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
+            An instance of
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
+            to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
+            needed.
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """

@@ -1327,8 +1342,10 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     elif model_type == "FrozenCLIPEmbedder":
-        text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+        text_model = convert_ldm_clip_checkpoint(
+            checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
+        )
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer

     if load_safety_checker:
         safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
```
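The same pair of arguments is available when calling the converter directly. A sketch under the same assumptions as above (placeholder checkpoint and local CLIP paths, not from this commit):

```python
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from transformers import CLIPTextModel, CLIPTokenizer

# Placeholder local paths; in an offline setup these would point at
# pre-downloaded copies of openai/clip-vit-large-patch14.
text_encoder = CLIPTextModel.from_pretrained("./clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("./clip-vit-large-patch14")

# With text_encoder and tokenizer supplied, the FrozenCLIPEmbedder branch
# above reuses them rather than calling from_pretrained on the Hub repo.
pipe = download_from_original_stable_diffusion_ckpt(
    "v1-5.ckpt",  # placeholder checkpoint path
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    local_files_only=True,
)
```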
