
Commit 22b9cb0

[From pretrained] Allow returning local path (huggingface#1450)
Allow returning local path
1 parent 25f850a commit 22b9cb0

File tree

2 files changed (+62, -28):

src/diffusers/pipeline_utils.py
tests/test_pipelines.py

src/diffusers/pipeline_utils.py

Lines changed: 33 additions & 28 deletions
```diff
@@ -377,7 +377,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
                 also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                 model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                 setting this argument to `True` will raise an error.
-
+            return_cached_folder (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the path to the downloaded cached folder will be returned in addition to the loaded pipeline.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                 specific pipeline class. The overwritten components are then directly passed to the pipelines
@@ -430,33 +431,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         sess_options = kwargs.pop("sess_options", None)
         device_map = kwargs.pop("device_map", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
-        if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `device_map=None`."
-            )
-
-        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `low_cpu_mem_usage=False`."
-            )
-
-        if low_cpu_mem_usage is False and device_map is not None:
-            raise ValueError(
-                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
-                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
-            )
+        return_cached_folder = kwargs.pop("return_cached_folder", False)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -585,6 +560,33 @@ def load_module(name, value):
                 f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
             )
 
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if device_map is not None and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `device_map=None`."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        if low_cpu_mem_usage is False and device_map is not None:
+            raise ValueError(
+                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+            )
+
         # import it here to avoid circular import
         from diffusers import pipelines
 
@@ -704,6 +706,9 @@ def load_module(name, value):
 
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
+
+        if return_cached_folder:
+            return model, cached_folder
         return model
 
     @staticmethod
```
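In practice, the new flag is a one-line way to discover where the snapshot was cached. A minimal usage sketch based on the diff above and the test below (the tiny internal test checkpoint stands in for any hub repo):

```python
from diffusers import StableDiffusionPipeline

# With return_cached_folder=True, from_pretrained returns the pipeline
# together with the path of the locally cached snapshot.
pipe, cached_folder = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    safety_checker=None,
    return_cached_folder=True,
)

# The returned path can be passed straight back to from_pretrained to
# reload the same pipeline from local disk.
pipe_2 = StableDiffusionPipeline.from_pretrained(cached_folder)
```

Note that the `accelerate` and torch-version guard checks now run after the checkpoint download rather than before it; the two large hunks above move that block unchanged.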

tests/test_pipelines.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -95,6 +95,35 @@ def test_download_only_pytorch(self):
             # We need to never convert this tiny model to safetensors for this test to pass
             assert not any(f.endswith(".safetensors") for f in files)
 
+    def test_returned_cached_folder(self):
+        prompt = "hello"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+        )
+        _, local_path = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, return_cached_folder=True
+        )
+        pipe_2 = StableDiffusionPipeline.from_pretrained(local_path)
+
+        pipe = pipe.to(torch_device)
+        pipe_2 = pipe_2.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
+        out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images
+
+        assert np.max(np.abs(out - out_2)) < 1e-3
+
     def test_download_safetensors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             # pipeline has Flax weights
```