Skip to content

Commit 299c16d

Browse files
authored
Fix loading Img2Img refiner components in from_single_file (huggingface#7282)
* update * update * update * update
1 parent 69f4919 commit 299c16d

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

src/diffusers/loaders/single_file.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def build_sub_model_components(
5656

5757
if component_name == "unet":
5858
num_in_channels = kwargs.pop("num_in_channels", None)
59+
upcast_attention = kwargs.pop("upcast_attention", None)
60+
5961
unet_components = create_diffusers_unet_model_from_ldm(
6062
pipeline_class_name,
6163
original_config,
@@ -64,6 +66,7 @@ def build_sub_model_components(
6466
image_size=image_size,
6567
torch_dtype=torch_dtype,
6668
model_type=model_type,
69+
upcast_attention=upcast_attention,
6770
)
6871
return unet_components
6972

@@ -300,7 +303,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
300303
continue
301304
init_kwargs.update(components)
302305

303-
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
306+
additional_components = set_additional_components(
307+
class_name, original_config, checkpoint=checkpoint, model_type=model_type
308+
)
304309
if additional_components:
305310
init_kwargs.update(additional_components)
306311

src/diffusers/loaders/single_file_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def is_valid_url(url):
410410
return original_config
411411

412412

413-
def infer_model_type(original_config, checkpoint=None, model_type=None):
413+
def infer_model_type(original_config, checkpoint, model_type=None):
414414
if model_type is not None:
415415
return model_type
416416

@@ -1279,7 +1279,7 @@ def create_diffusers_unet_model_from_ldm(
12791279
original_config,
12801280
checkpoint,
12811281
num_in_channels=None,
1282-
upcast_attention=False,
1282+
upcast_attention=None,
12831283
extract_ema=False,
12841284
image_size=None,
12851285
torch_dtype=None,
@@ -1307,7 +1307,8 @@ def create_diffusers_unet_model_from_ldm(
13071307
)
13081308
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
13091309
unet_config["in_channels"] = num_in_channels
1310-
unet_config["upcast_attention"] = upcast_attention
1310+
if upcast_attention is not None:
1311+
unet_config["upcast_attention"] = upcast_attention
13111312

13121313
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
13131314
ctx = init_empty_weights if is_accelerate_available() else nullcontext

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,11 @@ def test_single_file_component_configs(self):
838838
for param_name, param_value in single_file_pipe.unet.config.items():
839839
if param_name in PARAMS_TO_IGNORE:
840840
continue
841+
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
842+
pipe.unet.config[param_name] = False
841843
assert (
842844
pipe.unet.config[param_name] == param_value
843-
), f"{param_name} differs between single file loading and pretrained loading"
845+
), f"{param_name} is differs between single file loading and pretrained loading"
844846

845847
for param_name, param_value in single_file_pipe.vae.config.items():
846848
if param_name in PARAMS_TO_IGNORE:

0 commit comments

Comments
 (0)