Skip to content

Commit 8efd9ce

Browse files
authored
[Chore] clean residue from copy-pasting in the UNet single file loader (huggingface#7295)
clean residue from copy-pasting
1 parent 299c16d commit 8efd9ce

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/diffusers/loaders/unet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -905,14 +905,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
905905

906906
class FromOriginalUNetMixin:
907907
"""
908-
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
908+
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
909909
"""
910910

911911
@classmethod
912912
@validate_hf_hub_args
913913
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
914914
r"""
915-
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
915+
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
916916
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
917917
918918
Parameters:
@@ -951,6 +951,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
951951
Can be used to overwrite load and saveable variables of the model.
952952
953953
"""
954+
class_name = cls.__name__
955+
if class_name != "StableCascadeUNet":
956+
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
957+
954958
config = kwargs.pop("config", None)
955959
resume_download = kwargs.pop("resume_download", False)
956960
force_download = kwargs.pop("force_download", False)
@@ -961,10 +965,6 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
961965
revision = kwargs.pop("revision", None)
962966
torch_dtype = kwargs.pop("torch_dtype", None)
963967

964-
class_name = cls.__name__
965-
if class_name != "StableCascadeUNet":
966-
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
967-
968968
checkpoint = load_single_file_model_checkpoint(
969969
pretrained_model_link_or_path,
970970
resume_download=resume_download,

0 commit comments

Comments
 (0)