
Commit 8bff782

Improve single loading file (huggingface#4041)
* start improving single file load
* Fix more
* start improving single file load
* Fix sd 2.1
* further improve from_single_file
1 parent 6632823 commit 8bff782

File tree

5 files changed: +164 -35 lines changed


src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion
@@ -1389,7 +1389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         extract_ema = kwargs.pop("extract_ema", False)
-        image_size = kwargs.pop("image_size", 512)
+        image_size = kwargs.pop("image_size", None)
         scheduler_type = kwargs.pop("scheduler_type", "pndm")
         num_in_channels = kwargs.pop("num_in_channels", None)
         upcast_attention = kwargs.pop("upcast_attention", None)
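
With image_size now defaulting to None instead of 512, from_single_file can defer the resolution choice to the conversion code, which infers it from the model type (for example 1024 for SDXL checkpoints, as set further down in this diff). A minimal usage sketch; the checkpoint URLs are illustrative and the argument can still be passed explicitly:

from diffusers import StableDiffusionPipeline

# image_size is inferred from the checkpoint / model type when omitted.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
)

# An explicit value still overrides the inferred one.
pipe_768 = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors",
    image_size=768,
)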

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
AutoFeatureExtractor,
2525
BertTokenizerFast,
2626
CLIPImageProcessor,
27+
CLIPTextConfig,
2728
CLIPTextModel,
2829
CLIPTextModelWithProjection,
2930
CLIPTokenizer,
@@ -48,7 +49,7 @@
4849
PNDMScheduler,
4950
UnCLIPScheduler,
5051
)
51-
from ...utils import is_omegaconf_available, is_safetensors_available, logging
52+
from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
5253
from ...utils.import_utils import BACKENDS_MAPPING
5354
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
5455
from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@
5758
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
5859

5960

61+
if is_accelerate_available():
62+
from accelerate import init_empty_weights
63+
from accelerate.utils import set_module_tensor_to_device
64+
6065
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6166

6267

@@ -770,11 +775,12 @@ def _copy_layers(hf_layers, pt_layers):
770775

771776

772777
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
773-
text_model = (
774-
CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
775-
if text_encoder is None
776-
else text_encoder
777-
)
778+
if text_encoder is None:
779+
config_name = "openai/clip-vit-large-patch14"
780+
config = CLIPTextConfig.from_pretrained(config_name)
781+
782+
with init_empty_weights():
783+
text_model = CLIPTextModel(config)
778784

779785
keys = list(checkpoint.keys())
780786

@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
787793
if key.startswith(prefix):
788794
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
789795

790-
text_model.load_state_dict(text_model_dict)
796+
for param_name, param in text_model_dict.items():
797+
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
791798

792799
return text_model
793800
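
convert_ldm_clip_checkpoint now instantiates the text encoder under init_empty_weights(), which creates every parameter on the meta device without allocating storage, and then materializes the converted tensors one by one with set_module_tensor_to_device instead of calling load_state_dict; the same pattern is applied to the other converters further down. A minimal sketch of the pattern with a toy module (the module and state dict below are illustrative, not part of the diff):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    # Parameters exist only as shapes on the "meta" device here, no memory is allocated.
    model = torch.nn.Linear(4, 4)

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# Materialize each tensor directly, avoiding a second full copy of the weights.
for name, value in state_dict.items():
    set_module_tensor_to_device(model, name, "cpu", value=value)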

@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
     return model


-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

     keys = list(checkpoint.keys())

+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}

     if prefix + "text_projection" in checkpoint:
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

     for key in keys:
-        # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
-        #     continue
+        if key in keys_to_ignore:
+            continue
         if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):

         text_model_dict[new_key] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model
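
convert_open_clip_checkpoint no longer hard-codes a single pretrained model; the caller passes the Hub config to instantiate from, whether a projection head is needed, and any extra config kwargs. Illustrative calls matching the two ways the function is used later in this diff (checkpoint is assumed to be an already-loaded state dict):

# Stable Diffusion 2.x text encoder, no projection head.
text_model = convert_open_clip_checkpoint(
    checkpoint, "stabilityai/stable-diffusion-2", subfolder="text_encoder"
)

# SDXL secondary text encoder (OpenCLIP bigG) with a 1280-dim projection head.
text_encoder_2 = convert_open_clip_checkpoint(
    checkpoint,
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    prefix="conditioner.embedders.1.model.",
    has_projection=True,
    projection_dim=1280,
)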

@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
         if not is_safetensors_available():
             raise ValueError(BACKENDS_MAPPING["safetensors"][1])

-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load

-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
     else:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -1183,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.warning("global_step key not found in model")
+        logger.debug("global_step key not found in model")
         global_step = None

     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt(
             model_type = "SDXL"
         else:
             model_type = "SDXL-Refiner"
+        if image_size is None:
+            image_size = 1024

-    if num_in_channels is not None:
-        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+        num_in_channels = 9
+    elif num_in_channels is None:
+        num_in_channels = 4
+
+    original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

     if (
         "parameterization" in original_config["model"]["params"]
@@ -1263,7 +1287,6 @@ def download_from_original_stable_diffusion_ckpt(
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

     if model_type in ["SDXL", "SDXL-Refiner"]:
-        image_size = 1024
         scheduler_dict = {
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -1279,7 +1302,6 @@ def download_from_original_stable_diffusion_ckpt(
         }
         scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
         scheduler_type = "euler"
-        vae_path = "stabilityai/sdxl-vae"
     else:
         beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
         beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
@@ -1318,25 +1340,45 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)

     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

     # Convert the VAE model.
     if vae_path is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+
+        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

         if stable_unclip is None:
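
The VAE config now carries a scaling_factor read from the original LDM config (model.params.scale_factor) when present, with the Stable Diffusion default of 0.18215 as the fallback. A sketch of the same lookup against a plain dict config (a hypothetical helper, for illustration only):

def resolve_vae_scaling_factor(original_config):
    # Mirrors the model.params.scale_factor check, falling back to the SD default.
    return (
        original_config.get("model", {})
        .get("params", {})
        .get("scale_factor", 0.18215)
    )
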
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
         tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
         text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
         tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+
+        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config_kwargs = {"projection_dim": 1280}
+        text_encoder_2 = convert_open_clip_checkpoint(
+            checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+        )

         pipe = StableDiffusionXLPipeline(
             vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
         tokenizer = None
         text_encoder = None
         tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config_kwargs = {"projection_dim": 1280}
+        text_encoder_2 = convert_open_clip_checkpoint(
+            checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+        )

         pipe = StableDiffusionXLImg2ImgPipeline(
             vae=vae,
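
Both SDXL branches now build their second text encoder from the laion/CLIP-ViT-bigG-14-laion2B-39B-b160k config with a 1280-dim projection head instead of downloading pretrained weights that are immediately overwritten. End to end, loading an SDXL checkpoint in single-file format would look roughly like this (the URL is illustrative, and it assumes StableDiffusionXLPipeline exposes from_single_file, as its use in this converter suggests):

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)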

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 4 additions & 2 deletions
@@ -24,7 +24,7 @@

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.
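
With FromSingleFileMixin added to its bases, the inpainting pipeline can now be constructed straight from an original-format checkpoint, which is what the new tests below exercise. A short usage sketch using the checkpoint referenced in those tests:

import torch
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt",
    torch_dtype=torch.float16,
)
pipe.to("cuda")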

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 39 additions & 0 deletions
@@ -20,17 +20,20 @@

 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from diffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DPMSolverMultistepScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
@@ -512,6 +515,42 @@ def test_stable_diffusion_simple_inpaint_ddim(self):

         assert np.abs(expected_slice - image_slice).max() < 6e-4

+    def test_download_local(self):
+        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 1
+        image_out = pipe(**inputs).images[0]
+
+        assert image_out.shape == (512, 512, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image_ckpt = pipe(**inputs).images[0]
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image = pipe(**inputs).images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-4

 @nightly
 @require_torch_gpu

0 commit comments
