Skip to content

Commit 86a2676

Browse files
authored
Correctly handle creating model index json files when setting compiled modules in pipelines. (huggingface#6436)
update
1 parent 6ef2b8a commit 86a2676

File tree

1 file changed

+32
-28
lines changed

1 file changed

+32
-28
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,36 @@ def load_sub_model(
530530
return loaded_sub_model
531531

532532

533+
def _fetch_class_library_tuple(module):
534+
# import it here to avoid circular import
535+
diffusers_module = importlib.import_module(__name__.split(".")[0])
536+
pipelines = getattr(diffusers_module, "pipelines")
537+
538+
# register the config from the original module, not the dynamo compiled one
539+
not_compiled_module = _unwrap_model(module)
540+
library = not_compiled_module.__module__.split(".")[0]
541+
542+
# check if the module is a pipeline module
543+
module_path_items = not_compiled_module.__module__.split(".")
544+
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
545+
546+
path = not_compiled_module.__module__.split(".")
547+
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
548+
549+
# if library is not in LOADABLE_CLASSES, then it is a custom module.
550+
# Or if it's a pipeline module, then the module is inside the pipeline
551+
# folder so we set the library to module name.
552+
if is_pipeline_module:
553+
library = pipeline_dir
554+
elif library not in LOADABLE_CLASSES:
555+
library = not_compiled_module.__module__
556+
557+
# retrieve class_name
558+
class_name = not_compiled_module.__class__.__name__
559+
560+
return (library, class_name)
561+
562+
533563
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
534564
r"""
535565
Base class for all pipelines.
@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
556586
_is_onnx = False
557587

558588
def register_modules(self, **kwargs):
559-
# import it here to avoid circular import
560-
diffusers_module = importlib.import_module(__name__.split(".")[0])
561-
pipelines = getattr(diffusers_module, "pipelines")
562-
563589
for name, module in kwargs.items():
564590
# retrieve library
565591
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
566592
register_dict = {name: (None, None)}
567593
else:
568-
# register the config from the original module, not the dynamo compiled one
569-
not_compiled_module = _unwrap_model(module)
570-
571-
library = not_compiled_module.__module__.split(".")[0]
572-
573-
# check if the module is a pipeline module
574-
module_path_items = not_compiled_module.__module__.split(".")
575-
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
576-
577-
path = not_compiled_module.__module__.split(".")
578-
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
579-
580-
# if library is not in LOADABLE_CLASSES, then it is a custom module.
581-
# Or if it's a pipeline module, then the module is inside the pipeline
582-
# folder so we set the library to module name.
583-
if is_pipeline_module:
584-
library = pipeline_dir
585-
elif library not in LOADABLE_CLASSES:
586-
library = not_compiled_module.__module__
587-
588-
# retrieve class_name
589-
class_name = not_compiled_module.__class__.__name__
590-
594+
library, class_name = _fetch_class_library_tuple(module)
591595
register_dict = {name: (library, class_name)}
592596

593597
# save model index config
@@ -601,7 +605,7 @@ def __setattr__(self, name: str, value: Any):
601605
# We need to overwrite the config if name exists in config
602606
if isinstance(getattr(self.config, name), (tuple, list)):
603607
if value is not None and self.config[name][0] is not None:
604-
class_library_tuple = (value.__module__.split(".")[0], value.__class__.__name__)
608+
class_library_tuple = _fetch_class_library_tuple(value)
605609
else:
606610
class_library_tuple = (None, None)
607611

0 commit comments

Comments
 (0)