Skip to content

Commit 86ecd4b

Browse files
add from_ckpt method as Mixin (huggingface#2318)
* add mixin class for pipeline from original sd ckpt * Improve * make style * merge main into * Improve more * fix more * up * Apply suggestions from code review * finish docs * rename * make style --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent bdeff4d commit 86ecd4b

21 files changed

+410
-125
lines changed

docs/source/en/api/loaders.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
3636
### LoraLoaderMixin
3737

3838
[[autodoc]] loaders.LoraLoaderMixin
39+
40+
### FromCkptMixin
41+
42+
[[autodoc]] loaders.FromCkptMixin

docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
308308
- disable_vae_slicing
309309
- enable_xformers_memory_efficient_attention
310310
- disable_xformers_memory_efficient_attention
311+
- load_textual_inversion
311312

312313
## FlaxStableDiffusionControlNetPipeline
313314
[[autodoc]] FlaxStableDiffusionControlNetPipeline

docs/source/en/api/pipelines/stable_diffusion/depth2img.mdx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@ Available Checkpoints are:
3030
- enable_attention_slicing
3131
- disable_attention_slicing
3232
- enable_xformers_memory_efficient_attention
33-
- disable_xformers_memory_efficient_attention
33+
- disable_xformers_memory_efficient_attention
34+
- load_textual_inversion
35+
- load_lora_weights
36+
- save_lora_weights

docs/source/en/api/pipelines/stable_diffusion/img2img.mdx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
3030
- disable_attention_slicing
3131
- enable_xformers_memory_efficient_attention
3232
- disable_xformers_memory_efficient_attention
33+
- load_textual_inversion
34+
- from_ckpt
35+
- load_lora_weights
36+
- save_lora_weights
3337

3438
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
3539
- all
36-
- __call__
40+
- __call__

docs/source/en/api/pipelines/stable_diffusion/inpaint.mdx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ Available checkpoints are:
3131
- disable_attention_slicing
3232
- enable_xformers_memory_efficient_attention
3333
- disable_xformers_memory_efficient_attention
34+
- load_textual_inversion
35+
- load_lora_weights
36+
- save_lora_weights
3437

3538
[[autodoc]] FlaxStableDiffusionInpaintPipeline
3639
- all
37-
- __call__
40+
- __call__

docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
6868
[[autodoc]] StableDiffusionInstructPix2PixPipeline
6969
- __call__
7070
- all
71+
- load_textual_inversion
72+
- load_lora_weights
73+
- save_lora_weights

docs/source/en/api/pipelines/stable_diffusion/text2img.mdx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ Available Checkpoints are:
3939
- disable_xformers_memory_efficient_attention
4040
- enable_vae_tiling
4141
- disable_vae_tiling
42+
- load_textual_inversion
43+
- from_ckpt
44+
- load_lora_weights
45+
- save_lora_weights
4246

4347
[[autodoc]] FlaxStableDiffusionPipeline
4448
- all

src/diffusers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@
109109
except OptionalDependencyNotAvailable:
110110
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
111111
else:
112-
from .loaders import TextualInversionLoaderMixin
113112
from .pipelines import (
114113
AltDiffusionImg2ImgPipeline,
115114
AltDiffusionPipeline,

src/diffusers/loaders.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414
import os
1515
from collections import defaultdict
16+
from pathlib import Path
1617
from typing import Callable, Dict, List, Optional, Union
1718

1819
import torch
20+
from huggingface_hub import hf_hub_download
1921

2022
from .models.attention_processor import LoRAAttnProcessor
2123
from .utils import (
@@ -431,6 +433,7 @@ def load_textual_inversion(
431433
Example:
432434
433435
To load a textual inversion embedding vector in `diffusers` format:
436+
434437
```py
435438
from diffusers import StableDiffusionPipeline
436439
import torch
@@ -463,6 +466,7 @@ def load_textual_inversion(
463466
image = pipe(prompt, num_inference_steps=50).images[0]
464467
image.save("character.png")
465468
```
469+
466470
"""
467471
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
468472
raise ValueError(
@@ -1051,3 +1055,197 @@ def save_function(weights, filename):
10511055

10521056
save_function(state_dict, os.path.join(save_directory, weight_name))
10531057
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1058+
1059+
1060+
class FromCkptMixin:
    """Mixin that lets a pipeline class be instantiated directly from an original Stable Diffusion checkpoint file
    (`.ckpt` or `.safetensors`) through the `from_ckpt` class method."""

    @staticmethod
    def _ckpt_conversion_settings(pipeline_name):
        """Map a pipeline class name to the `(model_type, stable_unclip, controlnet)` conversion settings.

        Raises:
            `ValueError`: if `pipeline_name` is not one of the handled pipeline classes.
        """
        # The exact ControlNet name must be tested before the generic "StableDiffusion" substring match,
        # otherwise the ControlNet pipeline would be treated as a plain Stable Diffusion pipeline.
        if pipeline_name == "StableDiffusionControlNetPipeline":
            return "FrozenCLIPEmbedder", None, True
        if "StableDiffusion" in pipeline_name:
            return "FrozenCLIPEmbedder", None, False
        if pipeline_name == "StableUnCLIPPipeline":
            return "FrozenOpenCLIPEmbedder", "txt2img", False
        if pipeline_name == "StableUnCLIPImg2ImgPipeline":
            return "FrozenOpenCLIPEmbedder", "img2img", False
        if pipeline_name == "PaintByExamplePipeline":
            return "PaintByExample", None, False
        if pipeline_name == "LDMTextToImagePipeline":
            return "LDMTextToImage", None, False
        raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

    @staticmethod
    def _split_repo_and_file(hub_path):
        """Split `<user>/<repo>/[blob/][main/]<path_to_file>` into `(repo_id, file_path)`.

        `hub_path` is expected to already have its `https://huggingface.co/`-style prefix removed.
        """
        parts = Path(hub_path).parts
        # Join with "/" explicitly: Hub repo ids and file names always use forward slashes, whereas
        # `str(Path(...))` would produce backslashes on Windows.
        repo_id = "/".join(parts[:2])
        file_path = "/".join(parts[2:])

        if file_path.startswith("blob/"):
            file_path = file_path[len("blob/") :]

        if file_path.startswith("main/"):
            file_path = file_path[len("main/") :]

        return repo_id, file_path

    @classmethod
    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
        r"""
        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.

        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

        Parameters:
            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:
                    - A link to the .ckpt file on the Hub. Should be in the format
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
                    - A path to a *file* containing all pipeline weights.
            torch_dtype (`torch.dtype`, *optional*):
                If given, the loaded pipeline is moved to this dtype via `pipe.to(torch_dtype=...)`.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            use_safetensors (`bool`, *optional*):
                The checkpoint format is inferred from the file extension. Defaults to `None` if `safetensors` is
                installed and to `False` otherwise. If it is `False` and the checkpoint is a `.safetensors` file, a
                `ValueError` is raised.
            extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
                checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults
                to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
                inference. Non-EMA weights are usually better to continue fine-tuning.
            upcast_attention (`bool`, *optional*, defaults to `None`):
                Whether the attention computation should always be upcasted. This is necessary when running stable
                diffusion 2.1.
            image_size (`int`, *optional*, defaults to 512):
                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
                Base. Use 768 for Stable Diffusion v2.
            prediction_type (`str`, *optional*):
                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
                Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
            num_in_channels (`int`, *optional*, defaults to None):
                The number of input channels. If `None`, it will be automatically inferred.
            scheduler_type (`str`, *optional*, defaults to 'pndm'):
                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
                "ddim"]`.
            load_safety_checker (`bool`, *optional*, defaults to `True`):
                Whether to load the safety checker or not. Defaults to `True`.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Any other keyword arguments are accepted but are currently not used by this method.

        Raises:
            `ValueError`: if the calling pipeline class is not supported, or if a `.safetensors` checkpoint is
                requested while `use_safetensors` is `False` (its default when `safetensors` is not installed).

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
        ... )

        >>> # Load pipeline from a local file
        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")

        >>> # Enable float16 and move to GPU
        >>> pipeline = StableDiffusionPipeline.from_ckpt(
        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipeline.to("cuda")
        ```
        """
        # Validate the pipeline class first so that an unsupported class fails before any import or download.
        # TODO: For now we only support stable diffusion
        pipeline_name = cls.__name__
        model_type, stable_unclip, controlnet = cls._ckpt_conversion_settings(pipeline_name)

        # import here to avoid circular dependency
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", 512)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)

        torch_dtype = kwargs.pop("torch_dtype", None)

        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)

        # Accept `os.PathLike` objects as documented; the string methods below require a `str`.
        pretrained_model_link_or_path = os.fspath(pretrained_model_link_or_path)

        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
        from_safetensors = file_extension == "safetensors"

        # `use_safetensors` is `False` by default when the package is missing, which is exactly the case in
        # which a `.safetensors` checkpoint cannot be read.
        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

        # remove huggingface url
        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
            if pretrained_model_link_or_path.startswith(prefix):
                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        if not Path(pretrained_model_link_or_path).is_file():
            # Not a local file: treat it as a Hub location and fetch the checkpoint.
            # get repo_id and (potentially nested) file path of ckpt in repo
            repo_id, file_path = cls._split_repo_and_file(pretrained_model_link_or_path)

            pretrained_model_link_or_path = hf_hub_download(
                repo_id,
                filename=file_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                force_download=force_download,
            )

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,
            pipeline_class=cls,
            model_type=model_type,
            stable_unclip=stable_unclip,
            controlnet=controlnet,
            from_safetensors=from_safetensors,
            extract_ema=extract_ema,
            image_size=image_size,
            scheduler_type=scheduler_type,
            num_in_channels=num_in_channels,
            upcast_attention=upcast_attention,
            load_safety_checker=load_safety_checker,
            prediction_type=prediction_type,
        )

        if torch_dtype is not None:
            pipe.to(torch_dtype=torch_dtype)

        return pipe

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
5757
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
5858
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
5959
60+
In addition the pipeline inherits the following loading methods:
61+
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
62+
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
63+
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
64+
65+
as well as the following saving methods:
66+
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
67+
6068
Args:
6169
vae ([`AutoencoderKL`]):
6270
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

0 commit comments

Comments
 (0)