@@ -530,6 +530,36 @@ def load_sub_model(
530530 return loaded_sub_model
531531
532532
533+ def _fetch_class_library_tuple (module ):
534+ # import it here to avoid circular import
535+ diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
536+ pipelines = getattr (diffusers_module , "pipelines" )
537+
538+ # register the config from the original module, not the dynamo compiled one
539+ not_compiled_module = _unwrap_model (module )
540+ library = not_compiled_module .__module__ .split ("." )[0 ]
541+
542+ # check if the module is a pipeline module
543+ module_path_items = not_compiled_module .__module__ .split ("." )
544+ pipeline_dir = module_path_items [- 2 ] if len (module_path_items ) > 2 else None
545+
546+ path = not_compiled_module .__module__ .split ("." )
547+ is_pipeline_module = pipeline_dir in path and hasattr (pipelines , pipeline_dir )
548+
549+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
550+ # Or if it's a pipeline module, then the module is inside the pipeline
551+ # folder so we set the library to module name.
552+ if is_pipeline_module :
553+ library = pipeline_dir
554+ elif library not in LOADABLE_CLASSES :
555+ library = not_compiled_module .__module__
556+
557+ # retrieve class_name
558+ class_name = not_compiled_module .__class__ .__name__
559+
560+ return (library , class_name )
561+
562+
533563class DiffusionPipeline (ConfigMixin , PushToHubMixin ):
534564 r"""
535565 Base class for all pipelines.
@@ -556,38 +586,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
556586 _is_onnx = False
557587
558588 def register_modules (self , ** kwargs ):
559- # import it here to avoid circular import
560- diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
561- pipelines = getattr (diffusers_module , "pipelines" )
562-
563589 for name , module in kwargs .items ():
564590 # retrieve library
565591 if module is None or isinstance (module , (tuple , list )) and module [0 ] is None :
566592 register_dict = {name : (None , None )}
567593 else :
568- # register the config from the original module, not the dynamo compiled one
569- not_compiled_module = _unwrap_model (module )
570-
571- library = not_compiled_module .__module__ .split ("." )[0 ]
572-
573- # check if the module is a pipeline module
574- module_path_items = not_compiled_module .__module__ .split ("." )
575- pipeline_dir = module_path_items [- 2 ] if len (module_path_items ) > 2 else None
576-
577- path = not_compiled_module .__module__ .split ("." )
578- is_pipeline_module = pipeline_dir in path and hasattr (pipelines , pipeline_dir )
579-
580- # if library is not in LOADABLE_CLASSES, then it is a custom module.
581- # Or if it's a pipeline module, then the module is inside the pipeline
582- # folder so we set the library to module name.
583- if is_pipeline_module :
584- library = pipeline_dir
585- elif library not in LOADABLE_CLASSES :
586- library = not_compiled_module .__module__
587-
588- # retrieve class_name
589- class_name = not_compiled_module .__class__ .__name__
590-
594+ library , class_name = _fetch_class_library_tuple (module )
591595 register_dict = {name : (library , class_name )}
592596
593597 # save model index config
@@ -601,7 +605,7 @@ def __setattr__(self, name: str, value: Any):
601605 # We need to overwrite the config if name exists in config
602606 if isinstance (getattr (self .config , name ), (tuple , list )):
603607 if value is not None and self .config [name ][0 ] is not None :
604- class_library_tuple = (value . __module__ . split ( "." )[ 0 ], value . __class__ . __name__ )
608+ class_library_tuple = _fetch_class_library_tuple (value )
605609 else :
606610 class_library_tuple = (None , None )
607611
0 commit comments