1919import inspect
2020import os
2121import re
22+ import sys
2223import warnings
2324from dataclasses import dataclass
2425from pathlib import Path
@@ -540,11 +541,9 @@ def save_pretrained(
540541 variant (`str`, *optional*):
541542 If specified, weights are saved in the format pytorch_model.<variant>.bin.
542543 """
543- self .save_config (save_directory )
544-
545544 model_index_dict = dict (self .config )
546- model_index_dict .pop ("_class_name" )
547- model_index_dict .pop ("_diffusers_version" )
545+ model_index_dict .pop ("_class_name" , None )
546+ model_index_dict .pop ("_diffusers_version" , None )
548547 model_index_dict .pop ("_module" , None )
549548
550549 expected_modules , optional_kwargs = self ._get_signature_keys (self )
@@ -557,7 +556,6 @@ def is_saveable_module(name, value):
557556 return True
558557
559558 model_index_dict = {k : v for k , v in model_index_dict .items () if is_saveable_module (k , v )}
560-
561559 for pipeline_component_name in model_index_dict .keys ():
562560 sub_model = getattr (self , pipeline_component_name )
563561 model_cls = sub_model .__class__
@@ -571,7 +569,13 @@ def is_saveable_module(name, value):
571569 save_method_name = None
572570 # search for the model's base class in LOADABLE_CLASSES
573571 for library_name , library_classes in LOADABLE_CLASSES .items ():
574- library = importlib .import_module (library_name )
572+ if library_name in sys .modules :
573+ library = importlib .import_module (library_name )
574+ else :
575+ logger .info (
576+ f"{ library_name } is not installed. Cannot save { pipeline_component_name } as { library_classes } from { library_name } "
577+ )
578+
575579 for base_class , save_load_methods in library_classes .items ():
576580 class_candidate = getattr (library , base_class , None )
577581 if class_candidate is not None and issubclass (model_cls , class_candidate ):
@@ -581,6 +585,12 @@ def is_saveable_module(name, value):
581585 if save_method_name is not None :
582586 break
583587
588+ if save_method_name is None :
589+ logger .warn (f"self.{ pipeline_component_name } ={ sub_model } of type { type (sub_model )} cannot be saved." )
590+ # make sure that unsaveable components are not tried to be loaded afterward
591+ self .register_to_config (** {pipeline_component_name : (None , None )})
592+ continue
593+
584594 save_method = getattr (sub_model , save_method_name )
585595
586596 # Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ def is_saveable_module(name, value):
596606
597607 save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
598608
609+ # finally save the config
610+ self .save_config (save_directory )
611+
599612 def to (
600613 self ,
601614 torch_device : Optional [Union [str , torch .device ]] = None ,
0 commit comments