
Commit 4517117

Prevent online access when desired when using download_from_original_stable_diffusion_ckpt (huggingface#4271)
Prevent online access when desired:
- Bypass requests with a config files option added to download_from_original_stable_diffusion_ckpt
- Adds local_files_only flags to all from_pretrained requests
1 parent 4c4fe04 commit 4517117
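For context, a minimal usage sketch of the new offline path (hypothetical file and checkpoint names; assumes every referenced model is already in the local Hugging Face cache):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Per the docstring added in this commit, config_files maps version keys to
# the contents of the original .yaml config files, bypassing the online fetch.
with open("v1-inference.yaml") as f:  # hypothetical local copy of the SD v1 config
    v1_config = f.read()

pipe = download_from_original_stable_diffusion_ckpt(
    "model.ckpt",           # hypothetical local checkpoint
    local_files_only=True,  # resolve from the local cache instead of downloading
    config_files={"v1": v1_config},
)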

File tree

1 file changed (+32, -23 lines)

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 32 additions & 23 deletions
@@ -778,7 +778,7 @@ def _copy_layers(hf_layers, pt_layers):
 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
     if text_encoder is None:
         config_name = "openai/clip-vit-large-patch14"
-        config = CLIPTextConfig.from_pretrained(config_name)
+        config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)

     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
@@ -832,8 +832,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
 textenc_pattern = re.compile("|".join(protected.keys()))


-def convert_paint_by_example_checkpoint(checkpoint):
-    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
+def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
+    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
     model = PaintByExampleImageEncoder(config)

     keys = list(checkpoint.keys())
@@ -900,13 +900,13 @@ def convert_paint_by_example_checkpoint(checkpoint):


 def convert_open_clip_checkpoint(
-    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, local_files_only=False, **config_kwargs
 ):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
     # text_model = CLIPTextModelWithProjection.from_pretrained(
     #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
     # )
-    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)

     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
@@ -971,7 +971,7 @@ def convert_open_clip_checkpoint(
     return text_model


-def stable_unclip_image_encoder(original_config):
+def stable_unclip_image_encoder(original_config, local_files_only=False):
     """
     Returns the image processor and clip image encoder for the img2img unclip pipeline.

@@ -989,13 +989,13 @@ def stable_unclip_image_encoder(original_config):

         if clip_model_name == "ViT-L/14":
             feature_extractor = CLIPImageProcessor()
-            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
         else:
             raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

     elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
         feature_extractor = CLIPImageProcessor()
-        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only)
     else:
         raise NotImplementedError(
             f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
@@ -1116,6 +1116,7 @@ def download_from_original_stable_diffusion_ckpt(
     vae=None,
     text_encoder=None,
     tokenizer=None,
+    config_files=None,
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1175,6 +1176,14 @@
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
             to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
             needed.
+        config_files (`Dict[str, str]`, *optional*, defaults to `None`):
+            A dictionary mapping from config file names to their contents. If this parameter is `None`, the function
+            will load the config files by itself, if needed.
+            Valid keys are:
+                - `v1`: Config file for Stable Diffusion v1
+                - `v2`: Config file for Stable Diffusion v2
+                - `xl`: Config file for Stable Diffusion XL
+                - `xl_refiner`: Config file for Stable Diffusion XL Refiner
     return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """

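For illustration, the full mapping might be assembled like this (hypothetical file names; per the docstring above, the values are the config files' contents):

config_files = {
    "v1": open("v1-inference.yaml").read(),
    "v2": open("v2-inference-v.yaml").read(),
    "xl": open("sd_xl_base.yaml").read(),
    "xl_refiner": open("sd_xl_refiner.yaml").read(),
}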
@@ -1396,14 +1405,14 @@ def download_from_original_stable_diffusion_ckpt(
         else:
             vae.load_state_dict(converted_vae_checkpoint)
     elif vae is None:
-        vae = AutoencoderKL.from_pretrained(vae_path)
+        vae = AutoencoderKL.from_pretrained(vae_path, local_files_only=local_files_only)

     if model_type == "FrozenOpenCLIPEmbedder":
         config_name = "stabilityai/stable-diffusion-2"
         config_kwargs = {"subfolder": "text_encoder"}

         text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
-        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only)

         if stable_unclip is None:
             if controlnet:
@@ -1455,12 +1464,12 @@ def download_from_original_stable_diffusion_ckpt(
             elif stable_unclip == "txt2img":
                 if stable_unclip_prior is None or stable_unclip_prior == "karlo":
                     karlo_model = "kakaobrain/karlo-v1-alpha"
-                    prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
+                    prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior", local_files_only=local_files_only)

-                    prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-                    prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+                    prior_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+                    prior_text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)

-                    prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
+                    prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler", local_files_only=local_files_only)
                     prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
                 else:
                     raise NotImplementedError(f"unknown prior for stable unclip model: {stable_unclip_prior}")
@@ -1486,8 +1495,8 @@ def download_from_original_stable_diffusion_ckpt(
             raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
     elif model_type == "PaintByExample":
         vision_model = convert_paint_by_example_checkpoint(checkpoint)
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
         pipe = PaintByExamplePipeline(
             vae=vae,
             image_encoder=vision_model,
@@ -1500,11 +1509,11 @@ def download_from_original_stable_diffusion_ckpt(
         text_model = convert_ldm_clip_checkpoint(
             checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
         )
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if tokenizer is None else tokenizer
+        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) if tokenizer is None else tokenizer

         if load_safety_checker:
-            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
+            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only)
         else:
             safety_checker = None
             feature_extractor = None
@@ -1532,9 +1541,9 @@ def download_from_original_stable_diffusion_ckpt(
         )
     elif model_type in ["SDXL", "SDXL-Refiner"]:
         if model_type == "SDXL":
-            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only)

             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
@@ -1555,7 +1564,7 @@ def download_from_original_stable_diffusion_ckpt(
         else:
             tokenizer = None
             text_encoder = None
-            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only)

             config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
             config_kwargs = {"projection_dim": 1280}
@@ -1577,7 +1586,7 @@ def download_from_original_stable_diffusion_ckpt(
     else:
         text_config = create_ldm_bert_config(original_config)
         text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
-        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=local_files_only)
         pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

     return pipe
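As a quick sanity check of the behavior this commit relies on (a sketch, not part of the diff): with local_files_only=True, from_pretrained resolves exclusively from the local cache and raises an error instead of downloading.

from transformers import CLIPTokenizer

# Succeeds only if the tokenizer is already cached locally (e.g. under
# ~/.cache/huggingface); otherwise it raises OSError rather than going online.
tokenizer = CLIPTokenizer.from_pretrained(
    "openai/clip-vit-large-patch14", local_files_only=True
)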
