|
13 | 13 | # limitations under the License. |
14 | 14 | import os |
15 | 15 | from collections import defaultdict |
| 16 | +from pathlib import Path |
16 | 17 | from typing import Callable, Dict, List, Optional, Union |
17 | 18 |
|
18 | 19 | import torch |
| 20 | +from huggingface_hub import hf_hub_download |
19 | 21 |
|
20 | 22 | from .models.attention_processor import LoRAAttnProcessor |
21 | 23 | from .utils import ( |
@@ -431,6 +433,7 @@ def load_textual_inversion( |
431 | 433 | Example: |
432 | 434 |
|
433 | 435 | To load a textual inversion embedding vector in `diffusers` format: |
| 436 | +
|
434 | 437 | ```py |
435 | 438 | from diffusers import StableDiffusionPipeline |
436 | 439 | import torch |
@@ -463,6 +466,7 @@ def load_textual_inversion( |
463 | 466 | image = pipe(prompt, num_inference_steps=50).images[0] |
464 | 467 | image.save("character.png") |
465 | 468 | ``` |
| 469 | +
|
466 | 470 | """ |
467 | 471 | if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer): |
468 | 472 | raise ValueError( |
@@ -1051,3 +1055,197 @@ def save_function(weights, filename): |
1051 | 1055 |
|
1052 | 1056 | save_function(state_dict, os.path.join(save_directory, weight_name)) |
1053 | 1057 | logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") |
| 1058 | + |
| 1059 | + |
class FromCkptMixin:
    """Mixin that allows loading original Stable Diffusion checkpoint files (`.ckpt` or `.safetensors`) directly
    into the respective pipeline classes."""

    @classmethod
    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.

        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`):
                Can be either:
                    - A link to the .ckpt file on the Hub. Should be in the format
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`torch.dtype`, *optional*):
                If given, the loaded pipeline is cast to this dtype via `pipe.to(torch_dtype=...)`.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            use_safetensors (`bool`, *optional*):
                Whether `safetensors` may be used. Defaults to `None` if `safetensors` is installed and to `False`
                otherwise. Whether the checkpoint is read as `safetensors` is decided by its file extension; loading
                a `.safetensors` file while this is `False` raises a `ValueError`.
            extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
                checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults
                to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
                inference. Non-EMA weights are usually better to continue fine-tuning.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcasted. This is necessary when running Stable
                Diffusion 2.1.
            image_size (`int`, *optional*, defaults to 512):
                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
                Base. Use 768 for Stable Diffusion v2.
            prediction_type (`str`, *optional*):
                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
                Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
            num_in_channels (`int`, *optional*, defaults to None):
                The number of input channels. If `None`, it will be automatically inferred.
            scheduler_type (`str`, *optional*, defaults to 'pndm'):
                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
                "ddim"]`.
            load_safety_checker (`bool`, *optional*, defaults to `True`):
                Whether to load the safety checker or not. Defaults to `True`.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                specific pipeline class. The overwritten components are then directly passed to the pipelines
                `__init__` method. See example below for more information.

        Returns:
            The pipeline built by `download_from_original_stable_diffusion_ckpt` for `cls`, cast to `torch_dtype`
            when one was given.

        Raises:
            ValueError: If a `.safetensors` checkpoint is requested while `use_safetensors` is `False`, or if the
                calling pipeline class is not one of the handled pipeline types.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Load pipeline from a local file
        >>> # file was downloaded to ./v1-5-pruned-emaonly.ckpt
        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```
        """
        # import here to avoid circular dependency
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        # Hub download options
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)

        # Checkpoint conversion options
        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", 512)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        # `os.PathLike` inputs are documented as supported; the string operations below need a `str`.
        pretrained_model_link_or_path = str(pretrained_model_link_or_path)

        pipeline_name = cls.__name__
        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        # A safetensors checkpoint cannot be read when safetensors is unavailable or explicitly disabled.
        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

        # TODO: For now we only support stable diffusion
        stable_unclip = None
        controlnet = False

        # Map the calling pipeline class to the text-encoder/model type expected by the converter.
        if pipeline_name == "StableDiffusionControlNetPipeline":
            model_type = "FrozenCLIPEmbedder"
            controlnet = True
        elif "StableDiffusion" in pipeline_name:
            model_type = "FrozenCLIPEmbedder"
        elif pipeline_name == "StableUnCLIPPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "txt2img"
        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
            model_type = "FrozenOpenCLIPEmbedder"
            stable_unclip = "img2img"
        elif pipeline_name == "PaintByExamplePipeline":
            model_type = "PaintByExample"
        elif pipeline_name == "LDMTextToImagePipeline":
            model_type = "LDMTextToImage"
        else:
            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

        # remove huggingface url (a link carries at most one of these prefixes)
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
                break

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
        if not ckpt_path.is_file():
            # get repo_id and (potentially nested) file path of ckpt in repo;
            # join with "/" explicitly so Hub identifiers never contain OS-specific separators
            repo_id = "/".join(ckpt_path.parts[:2])
            file_path = "/".join(ckpt_path.parts[2:])

            # drop the "blob/main/" segment that browser URLs contain
            if file_path.startswith("blob/"):
                file_path = file_path[len("blob/") :]

            if file_path.startswith("main/"):
                file_path = file_path[len("main/") :]

            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                force_download=force_download,
            )

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            pipeline_class=cls,
            model_type=model_type,
            stable_unclip=stable_unclip,
            controlnet=controlnet,
            from_safetensors=from_safetensors,
            extract_ema=extract_ema,
            image_size=image_size,
            scheduler_type=scheduler_type,
            num_in_channels=num_in_channels,
            upcast_attention=upcast_attention,
            load_safety_checker=load_safety_checker,
            prediction_type=prediction_type,
        )

        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)

        return pipe
0 commit comments