Skip to content

Commit 080ecf0

Browse files
Improve loading pipe (huggingface#4009)
* improve loading subcomponents * Add test for logging * improve loading subcomponents * make style * make style * fix * finish
1 parent 7a91ea6 commit 080ecf0

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
556556
model_index_dict.pop("_class_name", None)
557557
model_index_dict.pop("_diffusers_version", None)
558558
model_index_dict.pop("_module", None)
559+
model_index_dict.pop("_name_or_path", None)
559560

560561
expected_modules, optional_kwargs = self._get_signature_keys(self)
561562

@@ -1013,7 +1014,7 @@ def load_module(name, value):
10131014
from diffusers import pipelines
10141015

10151016
# 6. Load each module in the pipeline
1016-
for name, (library_name, class_name) in init_dict.items():
1017+
for name, (library_name, class_name) in tqdm(init_dict.items(), desc="Loading pipeline components..."):
10171018
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
10181019
if class_name.startswith("Flax"):
10191020
class_name = class_name[4:]
@@ -1055,6 +1056,9 @@ def load_module(name, value):
10551056
low_cpu_mem_usage=low_cpu_mem_usage,
10561057
cached_folder=cached_folder,
10571058
)
1059+
logger.info(
1060+
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
1061+
)
10581062

10591063
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
10601064

@@ -1073,8 +1077,15 @@ def load_module(name, value):
10731077

10741078
# 8. Instantiate the pipeline
10751079
model = pipeline_class(**init_kwargs)
1080+
1081+
# 9. Save where the model was instantiated from
1082+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
10761083
return model
10771084

1085+
@property
1086+
def name_or_path(self) -> str:
1087+
return getattr(self.config, "_name_or_path", None)
1088+
10781089
@classmethod
10791090
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
10801091
r"""

tests/pipelines/test_pipelines.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,12 @@ def test_custom_model_and_pipeline(self):
821821
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
822822
pipe_new.save_pretrained(tmpdirname)
823823

824-
assert dict(pipe_new.config) == dict(pipe.config)
824+
conf_1 = dict(pipe.config)
825+
conf_2 = dict(pipe_new.config)
826+
827+
del conf_2["_name_or_path"]
828+
829+
assert conf_1 == conf_2
825830

826831
@slow
827832
@require_torch_gpu
@@ -1363,6 +1368,18 @@ def test_optional_components(self):
13631368
assert sd.config.safety_checker != (None, None)
13641369
assert sd.config.feature_extractor != (None, None)
13651370

1371+
def test_name_or_path(self):
1372+
model_path = "hf-internal-testing/tiny-stable-diffusion-torch"
1373+
sd = DiffusionPipeline.from_pretrained(model_path)
1374+
1375+
assert sd.name_or_path == model_path
1376+
1377+
with tempfile.TemporaryDirectory() as tmpdirname:
1378+
sd.save_pretrained(tmpdirname)
1379+
sd = DiffusionPipeline.from_pretrained(tmpdirname)
1380+
1381+
assert sd.name_or_path == tmpdirname
1382+
13661383
def test_warning_no_variant_available(self):
13671384
variant = "fp16"
13681385
with self.assertWarns(FutureWarning) as warning_context:

tests/pipelines/test_pipelines_common.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from diffusers.schedulers import KarrasDiffusionSchedulers
1818
from diffusers.utils import logging
1919
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
20-
from diffusers.utils.testing_utils import require_torch, torch_device
20+
from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
2121

2222

2323
def to_np(tensor):
@@ -298,9 +298,19 @@ def test_save_load_local(self, expected_max_difference=1e-4):
298298
inputs = self.get_dummy_inputs(torch_device)
299299
output = pipe(**inputs)[0]
300300

301+
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
302+
logger.setLevel(diffusers.logging.INFO)
303+
301304
with tempfile.TemporaryDirectory() as tmpdir:
302305
pipe.save_pretrained(tmpdir)
303-
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
306+
307+
with CaptureLogger(logger) as cap_logger:
308+
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
309+
310+
for name in pipe_loaded.components.keys():
311+
if name not in pipe_loaded._optional_components:
312+
assert name in str(cap_logger)
313+
304314
pipe_loaded.to(torch_device)
305315
pipe_loaded.set_progress_bar_config(disable=None)
306316

0 commit comments

Comments
 (0)