@@ -905,14 +905,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
905905
906906class FromOriginalUNetMixin :
907907 """
908- Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel `].
908+ Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet `].
909909 """
910910
911911 @classmethod
912912 @validate_hf_hub_args
913913 def from_single_file (cls , pretrained_model_link_or_path , ** kwargs ):
914914 r"""
915- Instantiate a [`ControlNetModel `] from pretrained ControlNet weights saved in the original `.ckpt` or
915+ Instantiate a [`StableCascadeUNet `] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
916916 `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
917917
918918 Parameters:
@@ -951,6 +951,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
951951 Can be used to overwrite load and saveable variables of the model.
952952
953953 """
954+ class_name = cls .__name__
955+ if class_name != "StableCascadeUNet" :
956+ raise ValueError ("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet" )
957+
954958 config = kwargs .pop ("config" , None )
955959 resume_download = kwargs .pop ("resume_download" , False )
956960 force_download = kwargs .pop ("force_download" , False )
@@ -961,10 +965,6 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
961965 revision = kwargs .pop ("revision" , None )
962966 torch_dtype = kwargs .pop ("torch_dtype" , None )
963967
964- class_name = cls .__name__
965- if class_name != "StableCascadeUNet" :
966- raise ValueError ("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet" )
967-
968968 checkpoint = load_single_file_model_checkpoint (
969969 pretrained_model_link_or_path ,
970970 resume_download = resume_download ,
0 commit comments