@@ -146,6 +146,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
146146 local_files_only = kwargs .pop ("local_files_only" , False )
147147 use_auth_token = kwargs .pop ("use_auth_token" , None )
148148 revision = kwargs .pop ("revision" , None )
149+ torch_dtype = kwargs .pop ("torch_dtype" , None )
149150
150151 # 1. Download the checkpoints and configs
151152 # use snapshot download here to get it working from from_pretrained
@@ -237,12 +238,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
237238
238239 load_method = getattr (class_obj , load_method_name )
239240
241+ loading_kwargs = {}
242+ if issubclass (class_obj , torch .nn .Module ):
243+ loading_kwargs ["torch_dtype" ] = torch_dtype
244+
240245 # check if the module is in a subdirectory
241246 if os .path .isdir (os .path .join (cached_folder , name )):
242- loaded_sub_model = load_method (os .path .join (cached_folder , name ))
247+ loaded_sub_model = load_method (os .path .join (cached_folder , name ), ** loading_kwargs )
243248 else :
244249 # else load from the root directory
245- loaded_sub_model = load_method (cached_folder )
250+ loaded_sub_model = load_method (cached_folder , ** loading_kwargs )
246251
247252 init_kwargs [name ] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
248253
0 commit comments