1515# limitations under the License.
1616
1717import importlib
18+ import inspect
1819import os
1920from typing import Optional , Union
2021
@@ -148,6 +149,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
148149 diffusers_module = importlib .import_module (cls .__module__ .split ("." )[0 ])
149150 pipeline_class = getattr (diffusers_module , config_dict ["_class_name" ])
150151
152+ # some modules can be passed directly to the init
153+ # in this case they are already instantiated in `kwargs`
154+ # extract them here
155+ expected_modules = set (inspect .signature (pipeline_class .__init__ ).parameters .keys ())
156+ passed_class_obj = {k : kwargs .pop (k ) for k in expected_modules if k in kwargs }
157+
151158 init_dict , _ = pipeline_class .extract_init_dict (config_dict , ** kwargs )
152159
153160 init_kwargs = {}
@@ -158,8 +165,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
158165 # 3. Load each module in the pipeline
159166 for name , (library_name , class_name ) in init_dict .items ():
160167 is_pipeline_module = hasattr (pipelines , library_name )
168+ loaded_sub_model = None
169+
161170 # if the model is in a pipeline module, then we load it from the pipeline
162- if is_pipeline_module :
171+ if name in passed_class_obj :
172+ # 1. check that passed_class_obj has correct parent class
173+ if not is_pipeline_module :
174+ library = importlib .import_module (library_name )
175+ class_obj = getattr (library , class_name )
176+ importable_classes = LOADABLE_CLASSES [library_name ]
177+ class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
178+
179+ expected_class_obj = None
180+ for class_name , class_candidate in class_candidates .items ():
181+ if issubclass (class_obj , class_candidate ):
182+ expected_class_obj = class_candidate
183+
184+ if not issubclass (passed_class_obj [name ].__class__ , expected_class_obj ):
185+ raise ValueError (
186+ f"{ passed_class_obj [name ]} is of type: { type (passed_class_obj [name ])} , but should be"
187+ f" { expected_class_obj } "
188+ )
189+ else :
190+ logger .warn (
191+ f"You have passed a non-standard module { passed_class_obj [name ]} . We cannot verify whether it"
192+ " has the correct type"
193+ )
194+
195+ # set passed class object
196+ loaded_sub_model = passed_class_obj [name ]
197+ elif is_pipeline_module :
163198 pipeline_module = getattr (pipelines , library_name )
164199 class_obj = getattr (pipeline_module , class_name )
165200 importable_classes = ALL_IMPORTABLE_CLASSES
@@ -171,23 +206,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
171206 importable_classes = LOADABLE_CLASSES [library_name ]
172207 class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
173208
174- load_method_name = None
175- for class_name , class_candidate in class_candidates .items ():
176- if issubclass (class_obj , class_candidate ):
177- load_method_name = importable_classes [class_name ][1 ]
209+ if loaded_sub_model is None :
210+ load_method_name = None
211+ for class_name , class_candidate in class_candidates .items ():
212+ if issubclass (class_obj , class_candidate ):
213+ load_method_name = importable_classes [class_name ][1 ]
178214
179- load_method = getattr (class_obj , load_method_name )
215+ load_method = getattr (class_obj , load_method_name )
180216
181- # check if the module is in a subdirectory
182- if os .path .isdir (os .path .join (cached_folder , name )):
183- loaded_sub_model = load_method (os .path .join (cached_folder , name ))
184- else :
185- # else load from the root directory
186- loaded_sub_model = load_method (cached_folder )
217+ # check if the module is in a subdirectory
218+ if os .path .isdir (os .path .join (cached_folder , name )):
219+ loaded_sub_model = load_method (os .path .join (cached_folder , name ))
220+ else :
221+ # else load from the root directory
222+ loaded_sub_model = load_method (cached_folder )
187223
188224 init_kwargs [name ] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
189225
190- # 5 . Instantiate the pipeline
226+ # 4 . Instantiate the pipeline
191227 model = pipeline_class (** init_kwargs )
192228 return model
193229
0 commit comments