Skip to content

Commit b7af946

Browse files
set config from original module but set compiled module on class (huggingface#3650)
* set config from original module but set compiled module on class * add test
1 parent d3717e6 commit b7af946

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -485,17 +485,19 @@ def register_modules(self, **kwargs):
485485
if module is None:
486486
register_dict = {name: (None, None)}
487487
else:
488-
# register the original module, not the dynamo compiled one
488+
# register the config from the original module, not the dynamo compiled one
489489
if is_compiled_module(module):
490-
module = module._orig_mod
490+
not_compiled_module = module._orig_mod
491+
else:
492+
not_compiled_module = module
491493

492-
library = module.__module__.split(".")[0]
494+
library = not_compiled_module.__module__.split(".")[0]
493495

494496
# check if the module is a pipeline module
495-
module_path_items = module.__module__.split(".")
497+
module_path_items = not_compiled_module.__module__.split(".")
496498
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
497499

498-
path = module.__module__.split(".")
500+
path = not_compiled_module.__module__.split(".")
499501
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
500502

501503
# if library is not in LOADABLE_CLASSES, then it is a custom module.
@@ -504,10 +506,10 @@ def register_modules(self, **kwargs):
504506
if is_pipeline_module:
505507
library = pipeline_dir
506508
elif library not in LOADABLE_CLASSES:
507-
library = module.__module__
509+
library = not_compiled_module.__module__
508510

509511
# retrieve class_name
510-
class_name = module.__class__.__name__
512+
class_name = not_compiled_module.__class__.__name__
511513

512514
register_dict = {name: (library, class_name)}
513515

tests/pipelines/test_pipelines.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
CONFIG_NAME,
6262
WEIGHTS_NAME,
6363
floats_tensor,
64+
is_compiled_module,
6465
nightly,
6566
require_torch_2,
6667
slow,
@@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
99100
scheduler = DDPMScheduler(num_train_timesteps=10)
100101

101102
ddpm = DDPMPipeline(model, scheduler)
103+
104+
# previous diffusers versions stripped compilation off
105+
# compiled modules
106+
assert is_compiled_module(ddpm.unet)
107+
102108
ddpm.to(torch_device)
103109
ddpm.set_progress_bar_config(disable=None)
104110

0 commit comments

Comments
 (0)